Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import java.util.AbstractList;
import java.util.AbstractMap;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand All @@ -46,7 +45,6 @@
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils.CharType;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils.TimeWithLocalTzType;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.logicaltypes.DateTime;
import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
Expand Down Expand Up @@ -95,6 +93,7 @@
import org.checkerframework.checker.nullness.qual.Nullable;
import org.codehaus.commons.compiler.CompileException;
import org.codehaus.janino.ScriptEvaluator;
import org.joda.time.DateTime;
import org.joda.time.Instant;
import org.joda.time.ReadableInstant;

Expand Down Expand Up @@ -375,7 +374,7 @@ private static Expression castOutputTime(Expression value, FieldType toType) {
// Convert TIME to LocalTime
if (value.getType() == java.sql.Time.class) {
valueDateTime = Expressions.call(BuiltInMethod.TIME_TO_INT.method, valueDateTime);
} else if (value.getType() == Long.class) {
} else if (value.getType() == Integer.class || value.getType() == Long.class) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this need to include Integer now?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The internal type is actually suppose to be int. We've been passing Calcite a Long instead and the compiler just happily upconverts when doing math between the two types. We are still doing math somewhere (window functions?) that turns this into a Long sometimes. I think even before this change we were at risk of receiving a Integer here. (In the next CL I'm switching this to the java Number interface so we are permissive on outputs.)
https://github.com/apache/calcite-avatica/blob/89e0deb510311b85b8c8bacde6d2ff70c309930e/core/src/main/java/org/apache/calcite/avatica/SqlType.java#L306

valueDateTime = Expressions.unbox(valueDateTime);
}
valueDateTime =
Expand All @@ -386,7 +385,7 @@ private static Expression castOutputTime(Expression value, FieldType toType) {
// Convert DATE to LocalDate
if (value.getType() == java.sql.Date.class) {
valueDateTime = Expressions.call(BuiltInMethod.DATE_TO_INT.method, valueDateTime);
} else if (value.getType() == Long.class) {
} else if (value.getType() == Integer.class || value.getType() == Long.class) {
valueDateTime = Expressions.unbox(valueDateTime);
}
valueDateTime = Expressions.call(LocalDate.class, "ofEpochDay", valueDateTime);
Expand Down Expand Up @@ -419,33 +418,6 @@ private static Expression castOutputTime(Expression value, FieldType toType) {
}

private static class InputGetterImpl implements RexToLixTranslator.InputGetter {
private static final Map<TypeName, Class> TYPE_CONVERSION_MAP =
ImmutableMap.<TypeName, Class>builder()
.put(TypeName.BYTE, Byte.class)
.put(TypeName.BYTES, byte[].class)
.put(TypeName.INT16, Short.class)
.put(TypeName.INT32, Integer.class)
.put(TypeName.INT64, Long.class)
.put(TypeName.DECIMAL, BigDecimal.class)
.put(TypeName.FLOAT, Float.class)
.put(TypeName.DOUBLE, Double.class)
.put(TypeName.STRING, String.class)
.put(TypeName.DATETIME, ReadableInstant.class)
.put(TypeName.BOOLEAN, Boolean.class)
.put(TypeName.MAP, Map.class)
.put(TypeName.ARRAY, Collection.class)
.put(TypeName.ITERABLE, Iterable.class)
.put(TypeName.ROW, Row.class)
.build();

private static final Map<String, Class> LOGICAL_TYPE_TO_BASE_TYPE_MAP =
ImmutableMap.<String, Class>builder()
.put(SqlTypes.DATE.getIdentifier(), Long.class)
.put(SqlTypes.TIME.getIdentifier(), Long.class)
.put(TimeWithLocalTzType.IDENTIFIER, ReadableInstant.class)
.put(SqlTypes.DATETIME.getIdentifier(), Row.class)
.put(CharType.IDENTIFIER, String.class)
.build();

private final Expression input;
private final Schema inputSchema;
Expand All @@ -457,84 +429,194 @@ private InputGetterImpl(Expression input, Schema inputSchema) {

@Override
public Expression field(BlockBuilder list, int index, Type storageType) {
return value(list, index, storageType, input, inputSchema);
return getBeamField(list, index, storageType, input, inputSchema);
}

private static Expression value(
// Read field from Beam Row
private static Expression getBeamField(
BlockBuilder list, int index, Type storageType, Expression input, Schema schema) {
if (index >= schema.getFieldCount() || index < 0) {
throw new IllegalArgumentException("Unable to find value #" + index);
}

final Expression expression = list.append(list.newName("current"), input);

FieldType fromType = schema.getField(index).getType();
Class convertTo = null;
if (storageType == Object.class) {
convertTo = Object.class;
} else if (fromType.getTypeName().isLogicalType()) {
convertTo = LOGICAL_TYPE_TO_BASE_TYPE_MAP.get(fromType.getLogicalType().getIdentifier());
} else {
convertTo = TYPE_CONVERSION_MAP.get(fromType.getTypeName());
}
if (convertTo == null) {
throw new UnsupportedOperationException("Unable to get " + fromType.getTypeName());
FieldType fieldType = schema.getField(index).getType();
Expression value;
switch (fieldType.getTypeName()) {
case BYTE:
value = Expressions.call(expression, "getByte", Expressions.constant(index));
break;
case INT16:
value = Expressions.call(expression, "getInt16", Expressions.constant(index));
break;
case INT32:
value = Expressions.call(expression, "getInt32", Expressions.constant(index));
break;
case INT64:
value = Expressions.call(expression, "getInt64", Expressions.constant(index));
break;
case DECIMAL:
value = Expressions.call(expression, "getDecimal", Expressions.constant(index));
break;
case FLOAT:
value = Expressions.call(expression, "getFloat", Expressions.constant(index));
break;
case DOUBLE:
value = Expressions.call(expression, "getDouble", Expressions.constant(index));
break;
case STRING:
value = Expressions.call(expression, "getString", Expressions.constant(index));
break;
case DATETIME:
value = Expressions.call(expression, "getDateTime", Expressions.constant(index));
break;
case BOOLEAN:
value = Expressions.call(expression, "getBoolean", Expressions.constant(index));
break;
case BYTES:
value = Expressions.call(expression, "getBytes", Expressions.constant(index));
break;
case ARRAY:
value = Expressions.call(expression, "getArray", Expressions.constant(index));
if (storageType == Object.class
&& TypeName.ROW.equals(fieldType.getCollectionElementType().getTypeName())) {
// Workaround for missing row output support
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the one special case that remains, it requires rewriting the output code to support rows. That is the next PR.

return Expressions.convert_(value, Object.class);
}
break;
case MAP:
value = Expressions.call(expression, "getMap", Expressions.constant(index));
break;
case ROW:
value = Expressions.call(expression, "getRow", Expressions.constant(index));
break;
case LOGICAL_TYPE:
String identifier = fieldType.getLogicalType().getIdentifier();
if (CharType.IDENTIFIER.equals(identifier)) {
value = Expressions.call(expression, "getString", Expressions.constant(index));
} else if (TimeWithLocalTzType.IDENTIFIER.equals(identifier)) {
value = Expressions.call(expression, "getDateTime", Expressions.constant(index));
} else if (SqlTypes.DATE.getIdentifier().equals(identifier)) {
value =
Expressions.convert_(
Expressions.call(
expression,
"getLogicalTypeValue",
Expressions.constant(index),
Expressions.constant(LocalDate.class)),
LocalDate.class);
} else if (SqlTypes.TIME.getIdentifier().equals(identifier)) {
value =
Expressions.convert_(
Expressions.call(
expression,
"getLogicalTypeValue",
Expressions.constant(index),
Expressions.constant(LocalTime.class)),
LocalTime.class);
} else if (SqlTypes.DATETIME.getIdentifier().equals(identifier)) {
value =
Expressions.convert_(
Expressions.call(
expression,
"getLogicalTypeValue",
Expressions.constant(index),
Expressions.constant(LocalDateTime.class)),
LocalDateTime.class);
} else {
throw new UnsupportedOperationException("Unable to get logical type " + identifier);
}
break;
default:
throw new UnsupportedOperationException("Unable to get " + fieldType.getTypeName());
}

Expression value =
Expressions.convert_(
Expressions.call(
expression,
"getBaseValue",
Expressions.constant(index),
Expressions.constant(convertTo)),
convertTo);
return (storageType != Object.class) ? value(value, fromType) : value;
return toCalciteValue(value, fieldType);
}

private static Expression value(Expression value, Schema.FieldType type) {
if (type.getTypeName().isLogicalType()) {
String logicalId = type.getLogicalType().getIdentifier();
if (SqlTypes.TIME.getIdentifier().equals(logicalId)) {
// Value conversion: Beam => Calcite
private static Expression toCalciteValue(Expression value, FieldType fieldType) {
switch (fieldType.getTypeName()) {
case BYTE:
return Expressions.convert_(value, Byte.class);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are these conversions necessary? e.g. doesn't getByte already return Byte?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These may be coming out of a Map or List as an Object. This isn't strictly necessary, I'm doing it to ensure types are what we expect. It generates a cast which results in a ClassCastException if it isn't the type we expected.

case INT16:
return Expressions.convert_(value, Short.class);
case INT32:
return Expressions.convert_(value, Integer.class);
case INT64:
return Expressions.convert_(value, Long.class);
case DECIMAL:
return Expressions.convert_(value, BigDecimal.class);
case FLOAT:
return Expressions.convert_(value, Float.class);
case DOUBLE:
return Expressions.convert_(value, Double.class);
case STRING:
return Expressions.convert_(value, String.class);
case BOOLEAN:
return Expressions.convert_(value, Boolean.class);
case DATETIME:
return nullOr(
value, Expressions.divide(value, Expressions.constant(NANOS_PER_MILLISECOND)));
} else if (SqlTypes.DATE.getIdentifier().equals(logicalId)) {
return value;
} else if (SqlTypes.DATETIME.getIdentifier().equals(logicalId)) {
Expression dateValue =
Expressions.call(value, "getInt64", Expressions.constant(DateTime.DATE_FIELD_NAME));
Expression timeValue =
Expressions.call(value, "getInt64", Expressions.constant(DateTime.TIME_FIELD_NAME));
Expression returnValue =
Expressions.add(
Expressions.multiply(dateValue, Expressions.constant(MILLIS_PER_DAY)),
Expressions.divide(timeValue, Expressions.constant(NANOS_PER_MILLISECOND)));
return nullOr(value, returnValue);
} else if (!CharType.IDENTIFIER.equals(logicalId)) {
throw new UnsupportedOperationException(
"Unknown LogicalType " + type.getLogicalType().getIdentifier());
}
} else if (type.getTypeName().isMapType()) {
return nullOr(value, map(value, type.getMapValueType()));
} else if (CalciteUtils.isDateTimeType(type)) {
return nullOr(value, Expressions.call(value, "getMillis"));
} else if (type.getTypeName().isCompositeType()) {
return nullOr(value, row(value, type.getRowSchema()));
} else if (type.getTypeName().isCollectionType()) {
return nullOr(value, list(value, type.getCollectionElementType()));
} else if (type.getTypeName() == TypeName.BYTES) {
return nullOr(
value, Expressions.new_(ByteString.class, Types.castIfNecessary(byte[].class, value)));
value, Expressions.call(Expressions.convert_(value, DateTime.class), "getMillis"));
case BYTES:
return nullOr(
value, Expressions.new_(ByteString.class, Expressions.convert_(value, byte[].class)));
case ARRAY:
return nullOr(value, toCalciteList(value, fieldType.getCollectionElementType()));
case MAP:
return nullOr(value, toCalciteMap(value, fieldType.getMapValueType()));
case ROW:
return nullOr(value, toCalciteRow(value, fieldType.getRowSchema()));
case LOGICAL_TYPE:
String identifier = fieldType.getLogicalType().getIdentifier();
if (CharType.IDENTIFIER.equals(identifier)) {
return Expressions.convert_(value, String.class);
} else if (TimeWithLocalTzType.IDENTIFIER.equals(identifier)) {
return nullOr(
value, Expressions.call(Expressions.convert_(value, DateTime.class), "getMillis"));
} else if (SqlTypes.DATE.getIdentifier().equals(identifier)) {
return nullOr(
value,
Expressions.call(
Expressions.box(
Expressions.call(
Expressions.convert_(value, LocalDate.class), "toEpochDay")),
"intValue"));
} else if (SqlTypes.TIME.getIdentifier().equals(identifier)) {
return nullOr(
value,
Expressions.call(
Expressions.box(
Expressions.divide(
Expressions.call(
Expressions.convert_(value, LocalTime.class), "toNanoOfDay"),
Expressions.constant(NANOS_PER_MILLISECOND))),
"intValue"));
} else if (SqlTypes.DATETIME.getIdentifier().equals(identifier)) {
value = Expressions.convert_(value, LocalDateTime.class);
Expression dateValue =
Expressions.call(Expressions.call(value, "toLocalDate"), "toEpochDay");
Expression timeValue =
Expressions.call(Expressions.call(value, "toLocalTime"), "toNanoOfDay");
Expression returnValue =
Expressions.add(
Expressions.multiply(dateValue, Expressions.constant(MILLIS_PER_DAY)),
Expressions.divide(timeValue, Expressions.constant(NANOS_PER_MILLISECOND)));
return nullOr(value, returnValue);
} else {
throw new UnsupportedOperationException("Unable to convert logical type " + identifier);
}
default:
throw new UnsupportedOperationException("Unable to convert " + fieldType.getTypeName());
}

return value;
}

private static Expression list(Expression input, FieldType elementType) {
private static Expression toCalciteList(Expression input, FieldType elementType) {
ParameterExpression value = Expressions.parameter(Object.class);

BlockBuilder block = new BlockBuilder();
block.add(value(value, elementType));
block.add(toCalciteValue(value, elementType));

return Expressions.new_(
WrappedList.class,
Expand All @@ -548,11 +630,11 @@ private static Expression list(Expression input, FieldType elementType) {
block.toBlock())));
}

private static Expression map(Expression input, FieldType mapValueType) {
private static Expression toCalciteMap(Expression input, FieldType mapValueType) {
ParameterExpression value = Expressions.parameter(Object.class);

BlockBuilder block = new BlockBuilder();
block.add(value(value, mapValueType));
block.add(toCalciteValue(value, mapValueType));

return Expressions.new_(
WrappedMap.class,
Expand All @@ -566,14 +648,14 @@ private static Expression map(Expression input, FieldType mapValueType) {
block.toBlock())));
}

private static Expression row(Expression input, Schema schema) {
private static Expression toCalciteRow(Expression input, Schema schema) {
ParameterExpression row = Expressions.parameter(Row.class);
ParameterExpression index = Expressions.parameter(int.class);
BlockBuilder body = new BlockBuilder(/* optimizing= */ false);

for (int i = 0; i < schema.getFieldCount(); i++) {
BlockBuilder list = new BlockBuilder(/* optimizing= */ false, body);
Expression returnValue = value(list, i, /* storageType= */ null, row, schema);
Expression returnValue = getBeamField(list, i, /* storageType= */ null, row, schema);

list.append(returnValue);

Expand Down
Loading