diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java index feab7a457669..15bfe6400bb4 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java @@ -23,6 +23,7 @@ import java.io.InputStream; import java.io.OutputStream; import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Modifier; import java.util.BitSet; import java.util.List; import java.util.Map; @@ -32,7 +33,6 @@ import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; -import org.apache.beam.sdk.schemas.Schema.TypeName; import org.apache.beam.sdk.schemas.SchemaCoder; import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.ByteBuddy; @@ -47,11 +47,9 @@ import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.FixedValue; import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.Implementation; import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.ByteCodeAppender; +import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.ByteCodeAppender.Size; import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.Duplication; import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.StackManipulation; -import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.StackManipulation.Compound; -import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.TypeCreation; -import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.collection.ArrayFactory; import 
org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.member.FieldAccess; import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.member.MethodInvocation; import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.member.MethodReturn; @@ -99,68 +97,100 @@ @Experimental(Kind.SCHEMAS) public abstract class RowCoderGenerator { private static final ByteBuddy BYTE_BUDDY = new ByteBuddy(); - private static final ForLoadedType CODER_TYPE = new ForLoadedType(Coder.class); - private static final ForLoadedType LIST_CODER_TYPE = new ForLoadedType(ListCoder.class); - private static final ForLoadedType ITERABLE_CODER_TYPE = new ForLoadedType(IterableCoder.class); - private static final ForLoadedType MAP_CODER_TYPE = new ForLoadedType(MapCoder.class); private static final BitSetCoder NULL_LIST_CODER = BitSetCoder.of(); private static final VarIntCoder VAR_INT_CODER = VarIntCoder.of(); - private static final ForLoadedType NULLABLE_CODER = new ForLoadedType(NullableCoder.class); private static final String CODERS_FIELD_NAME = "FIELD_CODERS"; - // A map of primitive types -> StackManipulations to create their coders. - private static final Map CODER_MAP; - // Cache for Coder class that are already generated. - private static Map> generatedCoders = Maps.newConcurrentMap(); - - static { - // Initialize the CODER_MAP with the StackManipulations to create the primitive coders. - // Assumes that each class contains a static of() constructor method. 
- CODER_MAP = Maps.newHashMap(); - for (Map.Entry entry : SchemaCoder.CODER_MAP.entrySet()) { - StackManipulation stackManipulation = - MethodInvocation.invoke( - new ForLoadedType(entry.getValue().getClass()) - .getDeclaredMethods() - .filter(ElementMatchers.named("of")) - .getOnly()); - CODER_MAP.putIfAbsent(entry.getKey(), stackManipulation); - } - } + private static final Map> GENERATED_CODERS = Maps.newConcurrentMap(); @SuppressWarnings("unchecked") public static Coder generate(Schema schema) { // Using ConcurrentHashMap::computeIfAbsent here would deadlock in case of nested // coders. Using HashMap::computeIfAbsent generates ConcurrentModificationExceptions in Java 11. - Coder rowCoder = generatedCoders.get(schema.getUUID()); + Coder rowCoder = GENERATED_CODERS.get(schema.getUUID()); if (rowCoder == null) { TypeDescription.Generic coderType = TypeDescription.Generic.Builder.parameterizedType(Coder.class, Row.class).build(); DynamicType.Builder builder = (DynamicType.Builder) BYTE_BUDDY.subclass(coderType); - builder = createComponentCoders(schema, builder); builder = implementMethods(schema, builder); + + Coder[] componentCoders = new Coder[schema.getFieldCount()]; + for (int i = 0; i < schema.getFieldCount(); ++i) { + // We use withNullable(false) as nulls are handled by the RowCoder and the individual + // component coders therefore do not need to handle nulls. 
+ componentCoders[i] = + SchemaCoder.coderForFieldType(schema.getField(i).getType().withNullable(false)); + } + + builder = + builder.defineField( + CODERS_FIELD_NAME, Coder[].class, Visibility.PRIVATE, FieldManifestation.FINAL); + + builder = + builder + .defineConstructor(Modifier.PUBLIC) + .withParameters(Coder[].class) + .intercept(new GeneratedCoderConstructor()); + try { rowCoder = builder .make() .load(Coder.class.getClassLoader(), ClassLoadingStrategy.Default.INJECTION) .getLoaded() - .getDeclaredConstructor() - .newInstance(); + .getDeclaredConstructor(Coder[].class) + .newInstance((Object) componentCoders); } catch (InstantiationException | IllegalAccessException | NoSuchMethodException | InvocationTargetException e) { - throw new RuntimeException("Unable to generate coder for schema " + schema); + throw new RuntimeException("Unable to generate coder for schema " + schema, e); } - generatedCoders.put(schema.getUUID(), rowCoder); + GENERATED_CODERS.put(schema.getUUID(), rowCoder); } return rowCoder; } + private static class GeneratedCoderConstructor implements Implementation { + @Override + public InstrumentedType prepare(InstrumentedType instrumentedType) { + return instrumentedType; + } + + @Override + public ByteCodeAppender appender(final Target implementationTarget) { + return (methodVisitor, implementationContext, instrumentedMethod) -> { + int numLocals = 1 + instrumentedMethod.getParameters().size(); + StackManipulation stackManipulation = + new StackManipulation.Compound( + // Call the base constructor. + MethodVariableAccess.loadThis(), + Duplication.SINGLE, + MethodInvocation.invoke( + new ForLoadedType(Coder.class) + .getDeclaredMethods() + .filter( + ElementMatchers.isConstructor().and(ElementMatchers.takesArguments(0))) + .getOnly()), + // Store the list of Coders as a member variable. 
+ MethodVariableAccess.REFERENCE.loadFrom(1), + FieldAccess.forField( + implementationTarget + .getInstrumentedType() + .getDeclaredFields() + .filter(ElementMatchers.named(CODERS_FIELD_NAME)) + .getOnly()) + .write(), + MethodReturn.VOID); + StackManipulation.Size size = stackManipulation.apply(methodVisitor, implementationContext); + return new Size(size.getMaximalSize(), numLocals); + }; + } + } + private static DynamicType.Builder implementMethods( Schema schema, DynamicType.Builder builder) { boolean hasNullableFields = @@ -185,6 +215,7 @@ public ByteCodeAppender appender(Target implementationTarget) { StackManipulation manipulation = new StackManipulation.Compound( // Array of coders. + MethodVariableAccess.loadThis(), FieldAccess.forField( implementationContext .getInstrumentedType() @@ -272,6 +303,7 @@ public ByteCodeAppender appender(Target implementationTarget) { .filter(ElementMatchers.named("getSchema")) .getOnly()), // Array of coders. + MethodVariableAccess.loadThis(), FieldAccess.forField( implementationContext .getInstrumentedType() @@ -313,7 +345,8 @@ static Row decodeDelegate(Schema schema, Coder[] coders, InputStream inputStream if (nullFields.get(i)) { fieldValues.add(null); } else { - fieldValues.add(coders[i].decode(inputStream)); + Object fieldValue = coders[i].decode(inputStream); + fieldValues.add(fieldValue); } } } @@ -329,120 +362,4 @@ static Row decodeDelegate(Schema schema, Coder[] coders, InputStream inputStream return Row.withSchema(schema).attachValues(fieldValues).build(); } } - - private static DynamicType.Builder createComponentCoders( - Schema schema, DynamicType.Builder builder) { - List componentCoders = - Lists.newArrayListWithCapacity(schema.getFieldCount()); - for (int i = 0; i < schema.getFieldCount(); i++) { - // We use withNullable(false) as nulls are handled by the RowCoder and the individual - // component coders therefore do not need to handle nulls. 
- componentCoders.add(getCoder(schema.getField(i).getType().withNullable(false))); - } - - return builder - // private static final Coder[] FIELD_CODERS; - .defineField( - CODERS_FIELD_NAME, - Coder[].class, - Visibility.PRIVATE, - Ownership.STATIC, - FieldManifestation.FINAL) - // Static initializer. - .initializer( - (methodVisitor, implementationContext, instrumentedMethod) -> { - StackManipulation manipulation = - new StackManipulation.Compound( - // Initialize the array of coders. - ArrayFactory.forType(CODER_TYPE.asGenericType()).withValues(componentCoders), - FieldAccess.forField( - implementationContext - .getInstrumentedType() - .getDeclaredFields() - .filter(ElementMatchers.named(CODERS_FIELD_NAME)) - .getOnly()) - .write()); - StackManipulation.Size size = - manipulation.apply(methodVisitor, implementationContext); - return new ByteCodeAppender.Size( - size.getMaximalSize(), instrumentedMethod.getStackSize()); - }); - } - - private static StackManipulation getCoder(Schema.FieldType fieldType) { - if (TypeName.LOGICAL_TYPE.equals(fieldType.getTypeName())) { - return getCoder(fieldType.getLogicalType().getBaseType()); - } else if (TypeName.ARRAY.equals(fieldType.getTypeName())) { - return listCoder(fieldType.getCollectionElementType()); - } else if (TypeName.ITERABLE.equals(fieldType.getTypeName())) { - return iterableCoder(fieldType.getCollectionElementType()); - } - if (TypeName.MAP.equals(fieldType.getTypeName())) {; - return mapCoder(fieldType.getMapKeyType(), fieldType.getMapValueType()); - } else if (TypeName.ROW.equals(fieldType.getTypeName())) { - checkState(fieldType.getRowSchema().getUUID() != null); - Coder nestedCoder = generate(fieldType.getRowSchema()); - return rowCoder(nestedCoder.getClass()); - } else { - StackManipulation primitiveCoder = coderForPrimitiveType(fieldType.getTypeName()); - - if (fieldType.getNullable()) { - primitiveCoder = - new Compound( - primitiveCoder, - MethodInvocation.invoke( - NULLABLE_CODER - .getDeclaredMethods() 
- .filter(ElementMatchers.named("of")) - .getOnly())); - } - - return primitiveCoder; - } - } - - private static StackManipulation listCoder(Schema.FieldType fieldType) { - StackManipulation componentCoder = getCoder(fieldType); - return new Compound( - componentCoder, - MethodInvocation.invoke( - LIST_CODER_TYPE.getDeclaredMethods().filter(ElementMatchers.named("of")).getOnly())); - } - - private static StackManipulation iterableCoder(Schema.FieldType fieldType) { - StackManipulation componentCoder = getCoder(fieldType); - return new Compound( - componentCoder, - MethodInvocation.invoke( - ITERABLE_CODER_TYPE - .getDeclaredMethods() - .filter(ElementMatchers.named("of")) - .getOnly())); - } - - static StackManipulation coderForPrimitiveType(Schema.TypeName typeName) { - return CODER_MAP.get(typeName); - } - - static StackManipulation mapCoder(Schema.FieldType keyType, Schema.FieldType valueType) { - StackManipulation keyCoder = getCoder(keyType); - StackManipulation valueCoder = getCoder(valueType); - return new Compound( - keyCoder, - valueCoder, - MethodInvocation.invoke( - MAP_CODER_TYPE.getDeclaredMethods().filter(ElementMatchers.named("of")).getOnly())); - } - - static StackManipulation rowCoder(Class coderClass) { - ForLoadedType loadedType = new ForLoadedType(coderClass); - return new Compound( - TypeCreation.of(loadedType), - Duplication.SINGLE, - MethodInvocation.invoke( - loadedType - .getDeclaredMethods() - .filter(ElementMatchers.isConstructor().and(ElementMatchers.takesArguments(0))) - .getOnly())); - } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java index db20388c25e8..5f5d23d827fc 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java @@ -30,6 +30,7 @@ import 
org.apache.beam.sdk.annotations.Experimental.Kind; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.Schema.TypeName; +import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType; import org.apache.beam.sdk.schemas.logicaltypes.OneOfType; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.Row; @@ -126,21 +127,25 @@ private ValueT fromValue( valueType, typeFactory); } else { - if (type.getTypeName().isLogicalType() - && OneOfType.IDENTIFIER.equals(type.getLogicalType().getIdentifier())) { + if (type.isLogicalType(OneOfType.IDENTIFIER)) { OneOfType oneOfType = type.getLogicalType(OneOfType.class); - OneOfType.Value oneOfValue = oneOfType.toInputType((Row) value); + EnumerationType oneOfEnum = oneOfType.getCaseEnumType(); + OneOfType.Value oneOfValue = (OneOfType.Value) value; FieldValueTypeInformation oneOfFieldValueTypeInformation = checkNotNull( - fieldValueTypeInformation.getOneOfTypes().get(oneOfValue.getCaseType().toString())); + fieldValueTypeInformation + .getOneOfTypes() + .get(oneOfEnum.toString(oneOfValue.getCaseType()))); Object fromValue = fromValue( - oneOfValue.getFieldType(), + oneOfType.getFieldType(oneOfValue), oneOfValue.getValue(), oneOfFieldValueTypeInformation.getRawType(), oneOfFieldValueTypeInformation, typeFactory); return (ValueT) oneOfType.createValue(oneOfValue.getCaseType(), fromValue); + } else if (type.getTypeName().isLogicalType()) { + return (ValueT) type.getLogicalType().toBaseType(value); } return value; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaCoder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaCoder.java index cf7b06f58d44..7b9bec340423 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaCoder.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaCoder.java @@ -22,59 +22,26 @@ import java.io.IOException; import java.io.InputStream; import 
java.io.OutputStream; -import java.util.Map; import java.util.Objects; import java.util.UUID; import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; -import org.apache.beam.sdk.coders.BigDecimalCoder; -import org.apache.beam.sdk.coders.BigEndianShortCoder; -import org.apache.beam.sdk.coders.BooleanCoder; -import org.apache.beam.sdk.coders.ByteArrayCoder; -import org.apache.beam.sdk.coders.ByteCoder; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CustomCoder; -import org.apache.beam.sdk.coders.DoubleCoder; -import org.apache.beam.sdk.coders.FloatCoder; -import org.apache.beam.sdk.coders.InstantCoder; -import org.apache.beam.sdk.coders.IterableCoder; -import org.apache.beam.sdk.coders.ListCoder; -import org.apache.beam.sdk.coders.MapCoder; import org.apache.beam.sdk.coders.RowCoder; import org.apache.beam.sdk.coders.RowCoderGenerator; -import org.apache.beam.sdk.coders.StringUtf8Coder; -import org.apache.beam.sdk.coders.VarIntCoder; -import org.apache.beam.sdk.coders.VarLongCoder; import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; -import org.apache.beam.sdk.schemas.Schema.TypeName; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; /** {@link SchemaCoder} is used as the coder for types that have schemas registered. */ @Experimental(Kind.SCHEMAS) public class SchemaCoder extends CustomCoder { - // This contains a map of primitive types to their coders. 
- public static final Map CODER_MAP = - ImmutableMap.builder() - .put(TypeName.BYTE, ByteCoder.of()) - .put(TypeName.BYTES, ByteArrayCoder.of()) - .put(TypeName.INT16, BigEndianShortCoder.of()) - .put(TypeName.INT32, VarIntCoder.of()) - .put(TypeName.INT64, VarLongCoder.of()) - .put(TypeName.DECIMAL, BigDecimalCoder.of()) - .put(TypeName.FLOAT, FloatCoder.of()) - .put(TypeName.DOUBLE, DoubleCoder.of()) - .put(TypeName.STRING, StringUtf8Coder.of()) - .put(TypeName.DATETIME, InstantCoder.of()) - .put(TypeName.BOOLEAN, BooleanCoder.of()) - .build(); - protected final Schema schema; private final TypeDescriptor typeDescriptor; private final SerializableFunction toRowFunction; @@ -119,27 +86,6 @@ public static SchemaCoder of(Schema schema) { return RowCoder.of(schema); } - /** Returns the coder used for a given primitive type. */ - public static Coder coderForFieldType(FieldType fieldType) { - switch (fieldType.getTypeName()) { - case ROW: - return (Coder) SchemaCoder.of(fieldType.getRowSchema()); - case ARRAY: - return (Coder) ListCoder.of(coderForFieldType(fieldType.getCollectionElementType())); - case ITERABLE: - return (Coder) IterableCoder.of(coderForFieldType(fieldType.getCollectionElementType())); - case MAP: - return (Coder) - MapCoder.of( - coderForFieldType(fieldType.getMapKeyType()), - coderForFieldType(fieldType.getMapValueType())); - case LOGICAL_TYPE: - return coderForFieldType(fieldType.getLogicalType().getBaseType()); - default: - return (Coder) CODER_MAP.get(fieldType.getTypeName()); - } - } - /** Returns the schema associated with this type. 
*/ public Schema getSchema() { return schema; @@ -186,7 +132,7 @@ private void verifyDeterministic(Schema schema) ImmutableList> coders = schema.getFields().stream() .map(Field::getType) - .map(SchemaCoder::coderForFieldType) + .map(SchemaCoderHelpers::coderForFieldType) .collect(ImmutableList.toImmutableList()); Coder.verifyDeterministic(this, "All fields must have deterministic encoding", coders); @@ -197,6 +143,10 @@ public boolean consistentWithEquals() { return true; } + public static Coder coderForFieldType(FieldType fieldType) { + return SchemaCoderHelpers.coderForFieldType(fieldType); + } + @Override public String toString() { return "SchemaCoder CODER_MAP = + ImmutableMap.builder() + .put(TypeName.BYTE, ByteCoder.of()) + .put(TypeName.BYTES, ByteArrayCoder.of()) + .put(TypeName.INT16, BigEndianShortCoder.of()) + .put(TypeName.INT32, VarIntCoder.of()) + .put(TypeName.INT64, VarLongCoder.of()) + .put(TypeName.DECIMAL, BigDecimalCoder.of()) + .put(TypeName.FLOAT, FloatCoder.of()) + .put(TypeName.DOUBLE, DoubleCoder.of()) + .put(TypeName.STRING, StringUtf8Coder.of()) + .put(TypeName.DATETIME, InstantCoder.of()) + .put(TypeName.BOOLEAN, BooleanCoder.of()) + .build(); + + private static class LogicalTypeCoder extends Coder { + private final LogicalType logicalType; + private final Coder baseTypeCoder; + private final boolean isDateTime; + + LogicalTypeCoder(LogicalType logicalType, Coder baseTypeCoder) { + this.logicalType = logicalType; + this.baseTypeCoder = baseTypeCoder; + this.isDateTime = logicalType.getBaseType().equals(FieldType.DATETIME); + } + + @Override + public void encode(InputT value, OutputStream outStream) throws CoderException, IOException { + BaseT baseType = logicalType.toBaseType(value); + if (isDateTime) { + baseType = (BaseT) ((ReadableInstant) baseType).toInstant(); + } + baseTypeCoder.encode(baseType, outStream); + } + + @Override + public InputT decode(InputStream inStream) throws CoderException, IOException { + BaseT baseType = 
baseTypeCoder.decode(inStream); + return logicalType.toInputType(baseType); + } + + @Override + public List> getCoderArguments() { + return Collections.emptyList(); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + baseTypeCoder.verifyDeterministic(); + } + + @Override + public boolean consistentWithEquals() { + // we can't assume that InputT is consistent with equals. + // TODO: We should plumb this through to logical types. + return false; + } + + @Override + public Object structuralValue(InputT value) { + if (baseTypeCoder.consistentWithEquals()) { + return logicalType.toBaseType(value); + } else { + return baseTypeCoder.structuralValue(logicalType.toBaseType(value)); + } + } + + @Override + public boolean isRegisterByteSizeObserverCheap(InputT value) { + return baseTypeCoder.isRegisterByteSizeObserverCheap(logicalType.toBaseType(value)); + } + + @Override + public void registerByteSizeObserver(InputT value, ElementByteSizeObserver observer) + throws Exception { + baseTypeCoder.registerByteSizeObserver(logicalType.toBaseType(value), observer); + } + } + + /** Returns the coder used for a given primitive type. 
*/ + public static Coder coderForFieldType(FieldType fieldType) { + Coder coder; + switch (fieldType.getTypeName()) { + case ROW: + coder = (Coder) SchemaCoder.of(fieldType.getRowSchema()); + break; + case ARRAY: + coder = (Coder) ListCoder.of(coderForFieldType(fieldType.getCollectionElementType())); + break; + case ITERABLE: + coder = + (Coder) IterableCoder.of(coderForFieldType(fieldType.getCollectionElementType())); + break; + case MAP: + coder = + (Coder) + MapCoder.of( + coderForFieldType(fieldType.getMapKeyType()), + coderForFieldType(fieldType.getMapValueType())); + break; + case LOGICAL_TYPE: + coder = + new LogicalTypeCoder( + fieldType.getLogicalType(), + coderForFieldType(fieldType.getLogicalType().getBaseType())); + break; + default: + coder = (Coder) CODER_MAP.get(fieldType.getTypeName()); + } + Preconditions.checkNotNull(coder, "Unexpected field type " + fieldType.getTypeName()); + if (fieldType.getNullable()) { + coder = NullableCoder.of(coder); + } + return coder; + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/EnumerationType.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/EnumerationType.java index aad729b4464c..f56928a33636 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/EnumerationType.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/EnumerationType.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.schemas.logicaltypes; +import java.io.Serializable; import java.util.Arrays; import java.util.Comparator; import java.util.List; @@ -74,12 +75,12 @@ public static EnumerationType create(String... enumValues) { } /** Return an {@link Value} corresponding to one of the enumeration strings. 
*/ public Value valueOf(String stringValue) { - return new Value(stringValue, enumValues.get(stringValue)); + return new Value(enumValues.get(stringValue)); } /** Return an {@link Value} corresponding to one of the enumeration integer values. */ public Value valueOf(int value) { - return new Value(enumValues.inverse().get(value), value); + return new Value(value); } @Override @@ -120,6 +121,10 @@ public List getValues() { return values; } + public String toString(EnumerationType.Value value) { + return enumValues.inverse().get(value.getValue()); + } + @Override public String toString() { return "Enumeration: " + enumValues; @@ -128,12 +133,10 @@ public String toString() { /** * This class represents a single enum value. It can be referenced as a String or as an integer. */ - public static class Value { - private final String stringValue; + public static class Value implements Serializable { private final int value; - public Value(String stringValue, int value) { - this.stringValue = stringValue; + public Value(int value) { this.value = value; } @@ -142,12 +145,6 @@ public int getValue() { return value; } - /** Return the String enum value. 
*/ - @Override - public String toString() { - return stringValue; - } - @Override public boolean equals(Object o) { if (this == o) { @@ -157,12 +154,17 @@ public boolean equals(Object o) { return false; } Value enumValue = (Value) o; - return value == enumValue.value && Objects.equals(stringValue, enumValue.stringValue); + return value == enumValue.value; } @Override public int hashCode() { - return Objects.hash(stringValue, value); + return Objects.hash(value); + } + + @Override + public String toString() { + return "enum value: " + value; } } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java index af2747519d8f..79c18199a97b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java @@ -23,6 +23,7 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.stream.Collectors; import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; @@ -129,13 +130,17 @@ public Value createValue(int caseValue, T value) { /** Create a {@link Value} specifying which field to set and the value to set. 
*/ public Value createValue(EnumerationType.Value caseType, T value) { - return new Value(caseType, oneOfSchema.getField(caseType.toString()).getType(), value); + return new Value(caseType, value); + } + + public FieldType getFieldType(OneOfType.Value oneOfValue) { + return oneOfSchema.getField(enumerationType.toString(oneOfValue.getCaseType())).getType(); } @Override public Row toBaseType(Value input) { EnumerationType.Value caseType = input.getCaseType(); - int setFieldIndex = oneOfSchema.indexOf(caseType.toString()); + int setFieldIndex = oneOfSchema.indexOf(enumerationType.toString(caseType)); Row.Builder builder = Row.withSchema(oneOfSchema); for (int i = 0; i < oneOfSchema.getFieldCount(); ++i) { Object value = (i == setFieldIndex) ? input.getValue() : null; @@ -171,12 +176,10 @@ public String toString() { */ public static class Value { private final EnumerationType.Value caseType; - private final FieldType fieldType; private final Object value; - public Value(EnumerationType.Value caseType, FieldType fieldType, Object value) { + public Value(EnumerationType.Value caseType, Object value) { this.caseType = caseType; - this.fieldType = fieldType; this.value = value; } @@ -195,14 +198,26 @@ public Object getValue() { return value; } - /** Return the type of this union field. 
*/ - public FieldType getFieldType() { - return fieldType; - } - @Override public String toString() { return "caseType: " + caseType + " Value: " + value; } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Value value1 = (Value) o; + return Objects.equals(caseType, value1.caseType) && Objects.equals(value, value1.value); + } + + @Override + public int hashCode() { + return Objects.hash(caseType, value); + } } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java index 2f9575334722..38b0edd73050 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java @@ -168,7 +168,21 @@ public CombineFieldsGlobally agg return new CombineFieldsGlobally<>( SchemaAggregateFn.create() .aggregateFields( - FieldAccessDescriptor.withFieldNames(inputFieldName), fn, outputFieldName)); + FieldAccessDescriptor.withFieldNames(inputFieldName), + false, + fn, + outputFieldName)); + } + + public + CombineFieldsGlobally aggregateFieldBaseValue( + String inputFieldName, + CombineFn fn, + String outputFieldName) { + return new CombineFieldsGlobally<>( + SchemaAggregateFn.create() + .aggregateFields( + FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputFieldName)); } /** The same as {@link #aggregateField} but using field id. 
*/ @@ -179,7 +193,18 @@ public CombineFieldsGlobally agg return new CombineFieldsGlobally<>( SchemaAggregateFn.create() .aggregateFields( - FieldAccessDescriptor.withFieldIds(inputFieldId), fn, outputFieldName)); + FieldAccessDescriptor.withFieldIds(inputFieldId), false, fn, outputFieldName)); + } + + public + CombineFieldsGlobally aggregateFieldBaseValue( + int inputFieldId, + CombineFn fn, + String outputFieldName) { + return new CombineFieldsGlobally<>( + SchemaAggregateFn.create() + .aggregateFields( + FieldAccessDescriptor.withFieldIds(inputFieldId), true, fn, outputFieldName)); } /** @@ -195,7 +220,18 @@ public CombineFieldsGlobally agg return new CombineFieldsGlobally<>( SchemaAggregateFn.create() .aggregateFields( - FieldAccessDescriptor.withFieldNames(inputFieldName), fn, outputField)); + FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputField)); + } + + public + CombineFieldsGlobally aggregateFieldBaseValue( + String inputFieldName, + CombineFn fn, + Field outputField) { + return new CombineFieldsGlobally<>( + SchemaAggregateFn.create() + .aggregateFields( + FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputField)); } /** The same as {@link #aggregateField} but using field id. 
*/ @@ -203,7 +239,19 @@ public CombineFieldsGlobally agg int inputFielId, CombineFn fn, Field outputField) { return new CombineFieldsGlobally<>( SchemaAggregateFn.create() - .aggregateFields(FieldAccessDescriptor.withFieldIds(inputFielId), fn, outputField)); + .aggregateFields( + FieldAccessDescriptor.withFieldIds(inputFielId), false, fn, outputField)); + } + + public + CombineFieldsGlobally aggregateFieldBaseValue( + int inputFielId, + CombineFn fn, + Field outputField) { + return new CombineFieldsGlobally<>( + SchemaAggregateFn.create() + .aggregateFields( + FieldAccessDescriptor.withFieldIds(inputFielId), true, fn, outputField)); } /** @@ -249,7 +297,8 @@ public CombineFieldsGlobally agg CombineFn fn, String outputFieldName) { return new CombineFieldsGlobally<>( - SchemaAggregateFn.create().aggregateFields(fieldsToAggregate, fn, outputFieldName)); + SchemaAggregateFn.create() + .aggregateFields(fieldsToAggregate, false, fn, outputFieldName)); } /** @@ -285,7 +334,7 @@ public CombineFieldsGlobally agg CombineFn fn, Field outputField) { return new CombineFieldsGlobally<>( - SchemaAggregateFn.create().aggregateFields(fieldsToAggregate, fn, outputField)); + SchemaAggregateFn.create().aggregateFields(fieldsToAggregate, false, fn, outputField)); } @Override @@ -341,7 +390,17 @@ public CombineFieldsGlobally agg String outputFieldName) { return new CombineFieldsGlobally<>( schemaAggregateFn.aggregateFields( - FieldAccessDescriptor.withFieldNames(inputFieldName), fn, outputFieldName)); + FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputFieldName)); + } + + public + CombineFieldsGlobally aggregateFieldBaseValue( + String inputFieldName, + CombineFn fn, + String outputFieldName) { + return new CombineFieldsGlobally<>( + schemaAggregateFn.aggregateFields( + FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputFieldName)); } public CombineFieldsGlobally aggregateField( @@ -350,7 +409,17 @@ public CombineFieldsGlobally agg String 
outputFieldName) { return new CombineFieldsGlobally<>( schemaAggregateFn.aggregateFields( - FieldAccessDescriptor.withFieldIds(inputFieldId), fn, outputFieldName)); + FieldAccessDescriptor.withFieldIds(inputFieldId), false, fn, outputFieldName)); + } + + public + CombineFieldsGlobally aggregateFieldBaseValue( + int inputFieldId, + CombineFn fn, + String outputFieldName) { + return new CombineFieldsGlobally<>( + schemaAggregateFn.aggregateFields( + FieldAccessDescriptor.withFieldIds(inputFieldId), true, fn, outputFieldName)); } /** @@ -365,14 +434,34 @@ public CombineFieldsGlobally agg Field outputField) { return new CombineFieldsGlobally<>( schemaAggregateFn.aggregateFields( - FieldAccessDescriptor.withFieldNames(inputFieldName), fn, outputField)); + FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputField)); + } + + public + CombineFieldsGlobally aggregateFieldBaseValue( + String inputFieldName, + CombineFn fn, + Field outputField) { + return new CombineFieldsGlobally<>( + schemaAggregateFn.aggregateFields( + FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputField)); } public CombineFieldsGlobally aggregateField( int inputFieldId, CombineFn fn, Field outputField) { return new CombineFieldsGlobally<>( schemaAggregateFn.aggregateFields( - FieldAccessDescriptor.withFieldIds(inputFieldId), fn, outputField)); + FieldAccessDescriptor.withFieldIds(inputFieldId), false, fn, outputField)); + } + + public + CombineFieldsGlobally aggregateFieldBaseValue( + int inputFieldId, + CombineFn fn, + Field outputField) { + return new CombineFieldsGlobally<>( + schemaAggregateFn.aggregateFields( + FieldAccessDescriptor.withFieldIds(inputFieldId), true, fn, outputField)); } /** @@ -417,7 +506,7 @@ public CombineFieldsGlobally agg CombineFn fn, String outputFieldName) { return new CombineFieldsGlobally<>( - schemaAggregateFn.aggregateFields(fieldAccessDescriptor, fn, outputFieldName)); + schemaAggregateFn.aggregateFields(fieldAccessDescriptor, 
false, fn, outputFieldName)); } /** @@ -453,7 +542,7 @@ public CombineFieldsGlobally agg CombineFn fn, Field outputField) { return new CombineFieldsGlobally<>( - schemaAggregateFn.aggregateFields(fieldAccessDescriptor, fn, outputField)); + schemaAggregateFn.aggregateFields(fieldAccessDescriptor, false, fn, outputField)); } @Override @@ -554,7 +643,21 @@ public CombineFieldsByFields agg this, SchemaAggregateFn.create() .aggregateFields( - FieldAccessDescriptor.withFieldNames(inputFieldName), fn, outputFieldName), + FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputFieldName), + getKeyField(), + getValueField()); + } + + public + CombineFieldsByFields aggregateFieldBaseValue( + String inputFieldName, + CombineFn fn, + String outputFieldName) { + return CombineFieldsByFields.of( + this, + SchemaAggregateFn.create() + .aggregateFields( + FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputFieldName), getKeyField(), getValueField()); } @@ -567,7 +670,21 @@ public CombineFieldsByFields agg this, SchemaAggregateFn.create() .aggregateFields( - FieldAccessDescriptor.withFieldIds(inputFieldId), fn, outputFieldName), + FieldAccessDescriptor.withFieldIds(inputFieldId), false, fn, outputFieldName), + getKeyField(), + getValueField()); + } + + public + CombineFieldsByFields aggregateFieldBaseValue( + int inputFieldId, + CombineFn fn, + String outputFieldName) { + return CombineFieldsByFields.of( + this, + SchemaAggregateFn.create() + .aggregateFields( + FieldAccessDescriptor.withFieldIds(inputFieldId), true, fn, outputFieldName), getKeyField(), getValueField()); } @@ -586,7 +703,21 @@ public CombineFieldsByFields agg this, SchemaAggregateFn.create() .aggregateFields( - FieldAccessDescriptor.withFieldNames(inputFieldName), fn, outputField), + FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputField), + getKeyField(), + getValueField()); + } + + public + CombineFieldsByFields aggregateFieldBaseValue( + String 
inputFieldName, + CombineFn fn, + Field outputField) { + return CombineFieldsByFields.of( + this, + SchemaAggregateFn.create() + .aggregateFields( + FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputField), getKeyField(), getValueField()); } @@ -596,7 +727,22 @@ public CombineFieldsByFields agg return CombineFieldsByFields.of( this, SchemaAggregateFn.create() - .aggregateFields(FieldAccessDescriptor.withFieldIds(inputFieldId), fn, outputField), + .aggregateFields( + FieldAccessDescriptor.withFieldIds(inputFieldId), false, fn, outputField), + getKeyField(), + getValueField()); + } + + public + CombineFieldsByFields aggregateFieldBaseValue( + int inputFieldId, + CombineFn fn, + Field outputField) { + return CombineFieldsByFields.of( + this, + SchemaAggregateFn.create() + .aggregateFields( + FieldAccessDescriptor.withFieldIds(inputFieldId), true, fn, outputField), getKeyField(), getValueField()); } @@ -644,7 +790,7 @@ public CombineFieldsByFields agg String outputFieldName) { return CombineFieldsByFields.of( this, - SchemaAggregateFn.create().aggregateFields(fieldsToAggregate, fn, outputFieldName), + SchemaAggregateFn.create().aggregateFields(fieldsToAggregate, false, fn, outputFieldName), getKeyField(), getValueField()); } @@ -683,7 +829,7 @@ public CombineFieldsByFields agg Field outputField) { return CombineFieldsByFields.of( this, - SchemaAggregateFn.create().aggregateFields(fieldsToAggregate, fn, outputField), + SchemaAggregateFn.create().aggregateFields(fieldsToAggregate, false, fn, outputField), getKeyField(), getValueField()); } @@ -792,7 +938,26 @@ public CombineFieldsByFields agg .setSchemaAggregateFn( getSchemaAggregateFn() .aggregateFields( - FieldAccessDescriptor.withFieldNames(inputFieldName), fn, outputFieldName)) + FieldAccessDescriptor.withFieldNames(inputFieldName), + false, + fn, + outputFieldName)) + .build(); + } + + public + CombineFieldsByFields aggregateFieldBaseValue( + String inputFieldName, + CombineFn fn, + String 
outputFieldName) { + return toBuilder() + .setSchemaAggregateFn( + getSchemaAggregateFn() + .aggregateFields( + FieldAccessDescriptor.withFieldNames(inputFieldName), + true, + fn, + outputFieldName)) .build(); } @@ -804,7 +969,20 @@ public CombineFieldsByFields agg .setSchemaAggregateFn( getSchemaAggregateFn() .aggregateFields( - FieldAccessDescriptor.withFieldIds(inputFieldId), fn, outputFieldName)) + FieldAccessDescriptor.withFieldIds(inputFieldId), false, fn, outputFieldName)) + .build(); + } + + public + CombineFieldsByFields aggregateFieldBaseValue( + int inputFieldId, + CombineFn fn, + String outputFieldName) { + return toBuilder() + .setSchemaAggregateFn( + getSchemaAggregateFn() + .aggregateFields( + FieldAccessDescriptor.withFieldIds(inputFieldId), true, fn, outputFieldName)) .build(); } @@ -822,7 +1000,20 @@ public CombineFieldsByFields agg .setSchemaAggregateFn( getSchemaAggregateFn() .aggregateFields( - FieldAccessDescriptor.withFieldNames(inputFieldName), fn, outputField)) + FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputField)) + .build(); + } + + public + CombineFieldsByFields aggregateFieldBaseValue( + String inputFieldName, + CombineFn fn, + Field outputField) { + return toBuilder() + .setSchemaAggregateFn( + getSchemaAggregateFn() + .aggregateFields( + FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputField)) .build(); } @@ -832,7 +1023,20 @@ public CombineFieldsByFields agg .setSchemaAggregateFn( getSchemaAggregateFn() .aggregateFields( - FieldAccessDescriptor.withFieldIds(inputFieldId), fn, outputField)) + FieldAccessDescriptor.withFieldIds(inputFieldId), false, fn, outputField)) + .build(); + } + + public + CombineFieldsByFields aggregateFieldBaseValue( + int inputFieldId, + CombineFn fn, + Field outputField) { + return toBuilder() + .setSchemaAggregateFn( + getSchemaAggregateFn() + .aggregateFields( + FieldAccessDescriptor.withFieldIds(inputFieldId), true, fn, outputField)) .build(); } @@ -870,7 
+1074,7 @@ public CombineFieldsByFields agg String outputFieldName) { return toBuilder() .setSchemaAggregateFn( - getSchemaAggregateFn().aggregateFields(fieldsToAggregate, fn, outputFieldName)) + getSchemaAggregateFn().aggregateFields(fieldsToAggregate, false, fn, outputFieldName)) .build(); } @@ -908,7 +1112,7 @@ public CombineFieldsByFields agg Field outputField) { return toBuilder() .setSchemaAggregateFn( - getSchemaAggregateFn().aggregateFields(fieldsToAggregate, fn, outputField)) + getSchemaAggregateFn().aggregateFields(fieldsToAggregate, false, fn, outputField)) .build(); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/SchemaAggregateFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/SchemaAggregateFn.java index 0ee29422caca..053978219207 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/SchemaAggregateFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/SchemaAggregateFn.java @@ -31,6 +31,7 @@ import org.apache.beam.sdk.schemas.FieldTypeDescriptors; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.Field; +import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.SchemaCoder; import org.apache.beam.sdk.schemas.utils.RowSelector; import org.apache.beam.sdk.schemas.utils.SelectHelpers; @@ -42,6 +43,7 @@ import org.apache.beam.sdk.transforms.SimpleFunction; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; /** This is the builder used by {@link Group} to build up a composed {@link CombineFn}. */ @@ -59,6 +61,7 @@ abstract static class Inner extends CombineFn { // Represents an aggregation of one or more fields. 
static class FieldAggregation implements Serializable { FieldAccessDescriptor fieldsToAggregate; + private final boolean aggregateBaseValues; // The specification of the output field. private final Field outputField; // The combine function. @@ -76,11 +79,13 @@ static class FieldAggregation implements Serializable { FieldAggregation( FieldAccessDescriptor fieldsToAggregate, + boolean aggregateBaseValues, Field outputField, CombineFn fn, TupleTag combineTag) { this( fieldsToAggregate, + aggregateBaseValues, outputField, fn, combineTag, @@ -90,13 +95,18 @@ static class FieldAggregation implements Serializable { FieldAggregation( FieldAccessDescriptor fieldsToAggregate, + boolean aggregateBaseValues, Field outputField, CombineFn fn, TupleTag combineTag, Schema aggregationSchema, @Nullable Schema inputSchema) { + this.aggregateBaseValues = aggregateBaseValues; if (inputSchema != null) { this.fieldsToAggregate = fieldsToAggregate.resolve(inputSchema); + if (aggregateBaseValues) { + Preconditions.checkArgument(fieldsToAggregate.referencesSingleField()); + } this.inputSubSchema = SelectHelpers.getOutputSchema(inputSchema, this.fieldsToAggregate); this.flattenedFieldAccessDescriptor = SelectHelpers.allLeavesDescriptor(inputSubSchema, SelectHelpers.CONCAT_FIELD_NAMES); @@ -120,7 +130,13 @@ static class FieldAggregation implements Serializable { // is known, resolve will be called with the proper schema. 
FieldAggregation resolve(Schema schema) { return new FieldAggregation<>( - fieldsToAggregate, outputField, fn, combineTag, aggregationSchema, schema); + fieldsToAggregate, + aggregateBaseValues, + outputField, + fn, + combineTag, + aggregationSchema, + schema); } } @@ -160,10 +176,17 @@ Inner withSchema(Schema inputSchema) { SimpleFunction extractFunction; Coder extractOutputCoder; if (fieldAggregation.fieldsToAggregate.referencesSingleField()) { - extractFunction = new ExtractSingleFieldFunction(inputSchema, fieldAggregation); - extractOutputCoder = - SchemaCoder.coderForFieldType( - fieldAggregation.flattenedInputSubSchema.getField(0).getType()); + extractFunction = + new ExtractSingleFieldFunction( + inputSchema, fieldAggregation.aggregateBaseValues, fieldAggregation); + + FieldType fieldType = fieldAggregation.flattenedInputSubSchema.getField(0).getType(); + if (fieldAggregation.aggregateBaseValues) { + while (fieldType.getTypeName().isLogicalType()) { + fieldType = fieldType.getLogicalType().getBaseType(); + } + } + extractOutputCoder = SchemaCoder.coderForFieldType(fieldType); } else { extractFunction = new ExtractFieldsFunction(inputSchema, fieldAggregation); extractOutputCoder = SchemaCoder.of(fieldAggregation.inputSubSchema); @@ -196,10 +219,12 @@ Inner withSchema(Schema inputSchema) { /** Aggregate all values of a set of fields into an output field. */ Inner aggregateFields( FieldAccessDescriptor fieldsToAggregate, + boolean aggregateBaseValues, CombineFn fn, String outputFieldName) { return aggregateFields( fieldsToAggregate, + aggregateBaseValues, fn, Field.of(outputFieldName, FieldTypeDescriptors.fieldTypeForJavaType(fn.getOutputType()))); } @@ -207,12 +232,14 @@ Inner aggregateFields( /** Aggregate all values of a set of fields into an output field. 
*/ Inner aggregateFields( FieldAccessDescriptor fieldsToAggregate, + boolean aggregateBaseValues, CombineFn fn, Field outputField) { List fieldAggregations = getFieldAggregations(); TupleTag combineTag = new TupleTag<>(Integer.toString(fieldAggregations.size())); FieldAggregation fieldAggregation = - new FieldAggregation<>(fieldsToAggregate, outputField, fn, combineTag); + new FieldAggregation<>( + fieldsToAggregate, aggregateBaseValues, outputField, fn, combineTag); fieldAggregations.add(fieldAggregation); return toBuilder() @@ -232,12 +259,15 @@ private Schema getOutputSchema(List fieldAggregations) { /** Extract a single field from an input {@link Row}. */ private static class ExtractSingleFieldFunction extends SimpleFunction { private final RowSelector rowSelector; + private final boolean extractBaseValue; @Nullable private final RowSelector flatteningSelector; private final FieldAggregation fieldAggregation; - private ExtractSingleFieldFunction(Schema inputSchema, FieldAggregation fieldAggregation) { + private ExtractSingleFieldFunction( + Schema inputSchema, boolean extractBaseValue, FieldAggregation fieldAggregation) { rowSelector = new RowSelectorContainer(inputSchema, fieldAggregation.fieldsToAggregate, true); + this.extractBaseValue = extractBaseValue; flatteningSelector = fieldAggregation.needsFlattening ? 
new RowSelectorContainer( @@ -254,6 +284,10 @@ public OutputT apply(Row row) { if (fieldAggregation.needsFlattening) { selected = flatteningSelector.select(selected); } + if (extractBaseValue + && selected.getSchema().getField(0).getType().getTypeName().isLogicalType()) { + return (OutputT) selected.getBaseValue(0, Object.class); + } return selected.getValue(0); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AvroUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AvroUtils.java index 3d4d523bc569..13d002bf208e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AvroUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AvroUtils.java @@ -910,7 +910,8 @@ private static Object genericFromBeamField( EnumerationType enumerationType = fieldType.getLogicalType(EnumerationType.class); return GenericData.get() .createEnum( - enumerationType.valueOf((int) value).toString(), typeWithNullability.type); + enumerationType.toString((EnumerationType.Value) value), + typeWithNullability.type); default: throw new RuntimeException( "Unhandled logical type " + fieldType.getLogicalType().getIdentifier()); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java index 5aacc0e95d5b..44085c33100f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java @@ -31,6 +31,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.SortedMap; import javax.annotation.Nullable; @@ -377,8 +378,6 @@ protected Type convertPrimitive(TypeDescriptor type) { @Override protected Type convertEnum(TypeDescriptor type) { - // We represent enums in the Row as Integers. 
The EnumerationType handles the mapping to the - // actual enum type. return Integer.class; } @@ -577,6 +576,28 @@ public Collection values() { public Set> entrySet() { return delegateMap.entrySet(); } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TransformingMap that = (TransformingMap) o; + return Objects.equals(delegateMap, that.delegateMap); + } + + @Override + public int hashCode() { + return Objects.hash(delegateMap); + } + + @Override + public String toString() { + return delegateMap.toString(); + } } /** @@ -1445,7 +1466,7 @@ public ByteCodeAppender appender(final Target implementationTarget) { stackManipulation, typeConversionsFactory .createSetterConversions(readParameter) - .convert(TypeDescriptor.of(parameter.getType()))); + .convert(TypeDescriptor.of(parameter.getParameterizedType()))); } stackManipulation = new StackManipulation.Compound( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ConvertHelpers.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ConvertHelpers.java index 277652a86cad..492c747ff512 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ConvertHelpers.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ConvertHelpers.java @@ -144,7 +144,7 @@ public static SerializableFunction getConvertPrimitive( } Type expectedInputType = - typeConversionsFactory.createTypeConversion(true).convert(outputTypeDescriptor); + typeConversionsFactory.createTypeConversion(false).convert(outputTypeDescriptor); TypeDescriptor outputType = outputTypeDescriptor; if (outputType.getRawType().isPrimitive()) { @@ -160,6 +160,7 @@ public static SerializableFunction getConvertPrimitive( .build(); DynamicType.Builder builder = (DynamicType.Builder) new ByteBuddy().subclass(genericType); + try { return builder .visit(new 
AsmVisitorWrapper.ForDeclaredMethods().writerFlags(ClassWriter.COMPUTE_FRAMES)) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java index 549d80ae11bb..37c3127bcad5 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java @@ -431,7 +431,7 @@ public ByteCodeAppender appender(final Target implementationTarget) { new StackManipulation.Compound( typeConversionsFactory .createGetterConversions(readValue) - .convert(TypeDescriptor.of(field.getType())), + .convert(TypeDescriptor.of(field.getGenericType())), MethodReturn.REFERENCE); StackManipulation.Size size = stackManipulation.apply(methodVisitor, implementationContext); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java index dd2da2e592a7..3d37b7388d38 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java @@ -32,13 +32,15 @@ import java.util.Map; import java.util.Objects; import java.util.stream.Collector; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; import org.apache.beam.sdk.schemas.Factory; import org.apache.beam.sdk.schemas.FieldValueGetter; import org.apache.beam.sdk.schemas.Schema; -import org.apache.beam.sdk.schemas.Schema.LogicalType; +import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.Schema.TypeName; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; @@ -73,6 +75,13 @@ public 
abstract class Row implements Serializable { /** Return the list of data values. */ public abstract List getValues(); + /** Return a list of data values. Any LogicalType values are returned as base values. * */ + public List getBaseValues() { + return IntStream.range(0, getFieldCount()) + .mapToObj(i -> getBaseValue(i)) + .collect(Collectors.toList()); + } + /** Get value by field name, {@link ClassCastException} is thrown if type doesn't match. */ @Nullable @SuppressWarnings("TypeParameterUnusedInFormals") @@ -205,8 +214,7 @@ public Map getMap(String fieldName) { return getMap(getSchema().indexOf(fieldName)); } - /** - * Returns the Logical Type input type for this field. {@link IllegalStateException} is thrown if + /* Returns the Logical Type input type for this field. {@link IllegalStateException} is thrown if * schema doesn't match. */ @Nullable @@ -214,6 +222,24 @@ public T getLogicalTypeValue(String fieldName, Class clazz) { return getLogicalTypeValue(getSchema().indexOf(fieldName), clazz); } + /** + * Returns the base type for this field. If this is a logical type, we convert to the base value. + * Otherwise the field itself is returned. + */ + @Nullable + public T getBaseValue(String fieldName, Class clazz) { + return getBaseValue(getSchema().indexOf(fieldName), clazz); + } + + /** + * Returns the base type for this field. If this is a logical type, we convert to the base value. + * Otherwise the field itself is returned. + */ + @Nullable + public Object getBaseValue(String fieldName) { + return getBaseValue(fieldName, Object.class); + } + /** * Get a {@link TypeName#ROW} value by field name, {@link IllegalStateException} is thrown if * schema doesn't match. 
@@ -356,8 +382,33 @@ public Map getMap(int idx) { */ @Nullable public T getLogicalTypeValue(int idx, Class clazz) { - LogicalType logicalType = checkNotNull(getSchema().getField(idx).getType().getLogicalType()); - return (T) logicalType.toInputType(getValue(idx)); + return (T) getValue(idx); + } + + /** + * Returns the base type for this field. If this is a logical type, we convert to the base value. + * Otherwise the field itself is returned. + */ + @Nullable + public T getBaseValue(int idx, Class clazz) { + Object value = getValue(idx); + FieldType fieldType = getSchema().getField(idx).getType(); + if (fieldType.getTypeName().isLogicalType() && value != null) { + while (fieldType.getTypeName().isLogicalType()) { + value = fieldType.getLogicalType().toBaseType(value); + fieldType = fieldType.getLogicalType().getBaseType(); + } + } + return (T) value; + } + + /** + * Returns the base type for this field. If this is a logical type, we convert to the base value. + * Otherwise the field itself is returned. 
+ */ + @Nullable + public Object getBaseValue(int idx) { + return getBaseValue(idx, Object.class); } /** diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java index 687c8fa2cc63..7a26cee17536 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java @@ -130,10 +130,11 @@ private T getValue(FieldType type, Object fieldValue, @Nullable Integer cach OneOfType oneOfType = type.getLogicalType(OneOfType.class); OneOfType.Value oneOfValue = (OneOfType.Value) fieldValue; Object convertedOneOfField = - getValue(oneOfValue.getFieldType(), oneOfValue.getValue(), null); - return (T) - oneOfType.toBaseType( - oneOfType.createValue(oneOfValue.getCaseType(), convertedOneOfField)); + getValue(oneOfType.getFieldType(oneOfValue), oneOfValue.getValue(), null); + return (T) oneOfType.createValue(oneOfValue.getCaseType(), convertedOneOfField); + } else if (type.getTypeName().isLogicalType()) { + // Getters are assumed to return the base type. + return (T) type.getLogicalType().toInputType(fieldValue); } return (T) fieldValue; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/SchemaVerification.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/SchemaVerification.java index e06e7028bcb2..b38051075020 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/SchemaVerification.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/SchemaVerification.java @@ -78,7 +78,8 @@ public static Object verifyFieldValue(Object value, FieldType type, String field } private static Object verifyLogicalType(Object value, LogicalType logicalType, String fieldName) { - return verifyFieldValue(logicalType.toBaseType(value), logicalType.getBaseType(), fieldName); + // TODO: this isn't guaranteed to clone the object. 
+ return logicalType.toInputType(logicalType.toBaseType(value)); } private static List verifyArray( diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/RowCoderTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/RowCoderTest.java index e3aa1ecb18d5..79202869bcaf 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/RowCoderTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/RowCoderTest.java @@ -25,6 +25,9 @@ import org.apache.beam.sdk.coders.Coder.NonDeterministicException; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.Schema.LogicalType; +import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType; +import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType.Value; import org.apache.beam.sdk.testing.CoderProperties; import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; @@ -158,6 +161,73 @@ public void testIterableOfIterable() throws Exception { CoderProperties.coderDecodeEncodeEqual(RowCoder.of(schema), row); } + @Test + public void testLogicalType() throws Exception { + EnumerationType enumeration = EnumerationType.create("one", "two", "three"); + Schema schema = Schema.builder().addLogicalTypeField("f_enum", enumeration).build(); + Row row = Row.withSchema(schema).addValue(enumeration.valueOf("two")).build(); + + CoderProperties.coderDecodeEncodeEqual(RowCoder.of(schema), row); + } + + @Test + public void testLogicalTypeInCollection() throws Exception { + EnumerationType enumeration = EnumerationType.create("one", "two", "three"); + Schema schema = + Schema.builder().addArrayField("f_enum_array", FieldType.logicalType(enumeration)).build(); + Row row = + Row.withSchema(schema) + .addArray(enumeration.valueOf("two"), enumeration.valueOf("three")) + .build(); + + CoderProperties.coderDecodeEncodeEqual(RowCoder.of(schema), row); + } + 
+ private static class NestedLogicalType implements LogicalType { + EnumerationType enumeration; + + NestedLogicalType(EnumerationType enumeration) { + this.enumeration = enumeration; + } + + @Override + public String getIdentifier() { + return ""; + } + + @Override + public FieldType getArgumentType() { + return FieldType.STRING; + } + + @Override + public FieldType getBaseType() { + return FieldType.logicalType(enumeration); + } + + @Override + public Value toBaseType(String input) { + return enumeration.valueOf(input); + } + + @Override + public String toInputType(Value base) { + return enumeration.toString(base); + } + } + + @Test + public void testNestedLogicalTypes() throws Exception { + EnumerationType enumeration = EnumerationType.create("one", "two", "three"); + Schema schema = + Schema.builder() + .addLogicalTypeField("f_nested_logical_type", new NestedLogicalType(enumeration)) + .build(); + Row row = Row.withSchema(schema).addValue("two").build(); + + CoderProperties.coderDecodeEncodeEqual(RowCoder.of(schema), row); + } + @Test(expected = NonDeterministicException.class) public void testVerifyDeterministic() throws NonDeterministicException { Schema schema = diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AvroSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AvroSchemaTest.java index e12331be1a2c..dd768a9a4b44 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AvroSchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AvroSchemaTest.java @@ -477,7 +477,7 @@ public void testPojoRecordFromRowSerializable() { @Test @Category(ValidatesRunner.class) public void testAvroPipelineGroupBy() { - PCollection input = pipeline.apply(Create.of(ROW_FOR_POJO)).setRowSchema(POJO_SCHEMA); + PCollection input = pipeline.apply(Create.of(ROW_FOR_POJO).withRowSchema(POJO_SCHEMA)); PCollection output = input.apply(Group.byFieldNames("string")); Schema keySchema = 
Schema.builder().addStringField("string").build(); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java index 9dcdcd491389..af3aae015b75 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.schemas; +import static org.apache.beam.sdk.schemas.utils.TestPOJOs.ENUMERATION; import static org.apache.beam.sdk.schemas.utils.TestPOJOs.NESTED_ARRAYS_POJO_SCHEMA; import static org.apache.beam.sdk.schemas.utils.TestPOJOs.NESTED_ARRAY_POJO_SCHEMA; import static org.apache.beam.sdk.schemas.utils.TestPOJOs.NESTED_MAP_POJO_SCHEMA; @@ -592,20 +593,32 @@ public void testEnumFieldToRow() throws NoSuchSchemaException { SchemaRegistry registry = SchemaRegistry.createDefault(); Schema schema = registry.getSchema(PojoWithEnum.class); SchemaTestUtils.assertSchemaEquivalent(POJO_WITH_ENUM_SCHEMA, schema); - EnumerationType enumerationType = - POJO_WITH_ENUM_SCHEMA.getField(0).getType().getLogicalType(EnumerationType.class); + EnumerationType enumerationType = ENUMERATION; + List allColors = + Lists.newArrayList( + enumerationType.valueOf("RED"), + enumerationType.valueOf("GREEN"), + enumerationType.valueOf("BLUE")); Row redRow = - Row.withSchema(POJO_WITH_ENUM_SCHEMA).addValue(enumerationType.valueOf("RED")).build(); + Row.withSchema(POJO_WITH_ENUM_SCHEMA) + .addValues(enumerationType.valueOf("RED"), allColors) + .build(); Row greenRow = - Row.withSchema(POJO_WITH_ENUM_SCHEMA).addValue(enumerationType.valueOf("GREEN")).build(); + Row.withSchema(POJO_WITH_ENUM_SCHEMA) + .addValues(enumerationType.valueOf("GREEN"), allColors) + .build(); Row blueRow = - Row.withSchema(POJO_WITH_ENUM_SCHEMA).addValue(enumerationType.valueOf("BLUE")).build(); + Row.withSchema(POJO_WITH_ENUM_SCHEMA) + 
.addValues(enumerationType.valueOf("BLUE"), allColors) + .build(); + + List allColorsJava = Lists.newArrayList(Color.RED, Color.GREEN, Color.BLUE); SerializableFunction toRow = registry.getToRowFunction(PojoWithEnum.class); - assertEquals(redRow, toRow.apply(new PojoWithEnum(Color.RED))); - assertEquals(greenRow, toRow.apply(new PojoWithEnum(Color.GREEN))); - assertEquals(blueRow, toRow.apply(new PojoWithEnum(Color.BLUE))); + assertEquals(redRow, toRow.apply(new PojoWithEnum(Color.RED, allColorsJava))); + assertEquals(greenRow, toRow.apply(new PojoWithEnum(Color.GREEN, allColorsJava))); + assertEquals(blueRow, toRow.apply(new PojoWithEnum(Color.BLUE, allColorsJava))); } @Test @@ -613,20 +626,32 @@ public void testEnumFieldFromRow() throws NoSuchSchemaException { SchemaRegistry registry = SchemaRegistry.createDefault(); Schema schema = registry.getSchema(PojoWithEnum.class); SchemaTestUtils.assertSchemaEquivalent(POJO_WITH_ENUM_SCHEMA, schema); - EnumerationType enumerationType = - POJO_WITH_ENUM_SCHEMA.getField(0).getType().getLogicalType(EnumerationType.class); + EnumerationType enumerationType = ENUMERATION; + + List allColors = + Lists.newArrayList( + enumerationType.valueOf("RED"), + enumerationType.valueOf("GREEN"), + enumerationType.valueOf("BLUE")); Row redRow = - Row.withSchema(POJO_WITH_ENUM_SCHEMA).addValue(enumerationType.valueOf("RED")).build(); + Row.withSchema(POJO_WITH_ENUM_SCHEMA) + .addValues(enumerationType.valueOf("RED"), allColors) + .build(); Row greenRow = - Row.withSchema(POJO_WITH_ENUM_SCHEMA).addValue(enumerationType.valueOf("GREEN")).build(); + Row.withSchema(POJO_WITH_ENUM_SCHEMA) + .addValues(enumerationType.valueOf("GREEN"), allColors) + .build(); Row blueRow = - Row.withSchema(POJO_WITH_ENUM_SCHEMA).addValue(enumerationType.valueOf("BLUE")).build(); + Row.withSchema(POJO_WITH_ENUM_SCHEMA) + .addValues(enumerationType.valueOf("BLUE"), allColors) + .build(); SerializableFunction fromRow = registry.getFromRowFunction(PojoWithEnum.class); 
- assertEquals(new PojoWithEnum(Color.RED), fromRow.apply(redRow)); - assertEquals(new PojoWithEnum(Color.GREEN), fromRow.apply(greenRow)); - assertEquals(new PojoWithEnum(Color.BLUE), fromRow.apply(blueRow)); + List allColorsJava = Lists.newArrayList(Color.RED, Color.GREEN, Color.BLUE); + assertEquals(new PojoWithEnum(Color.RED, allColorsJava), fromRow.apply(redRow)); + assertEquals(new PojoWithEnum(Color.GREEN, allColorsJava), fromRow.apply(greenRow)); + assertEquals(new PojoWithEnum(Color.BLUE, allColorsJava), fromRow.apply(blueRow)); } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/logicaltypes/LogicalTypesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/logicaltypes/LogicalTypesTest.java index 130b7acfd0b7..f52f36057a3a 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/logicaltypes/LogicalTypesTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/logicaltypes/LogicalTypesTest.java @@ -38,9 +38,9 @@ public void testEnumeration() { EnumerationType enumeration = EnumerationType.create(enumMap); assertEquals(enumeration.valueOf(1), enumeration.valueOf("FIRST")); assertEquals(enumeration.valueOf(2), enumeration.valueOf("SECOND")); - assertEquals("FIRST", enumeration.valueOf(1).toString()); + assertEquals("FIRST", enumeration.toString(enumeration.valueOf(1))); assertEquals(1, enumeration.valueOf("FIRST").getValue()); - assertEquals("SECOND", enumeration.valueOf(2).toString()); + assertEquals("SECOND", enumeration.toString(enumeration.valueOf(2))); assertEquals(2, enumeration.valueOf("SECOND").getValue()); Schema schema = @@ -65,12 +65,12 @@ public void testOneOf() { Row stringOneOf = Row.withSchema(schema).addValue(oneOf.createValue("string", "stringValue")).build(); Value union = stringOneOf.getLogicalTypeValue(0, OneOfType.Value.class); - assertEquals("string", union.getCaseType().toString()); + assertEquals("string", oneOf.getCaseEnumType().toString(union.getCaseType())); 
assertEquals("stringValue", union.getValue()); Row intOneOf = Row.withSchema(schema).addValue(oneOf.createValue("int32", 42)).build(); union = intOneOf.getLogicalTypeValue(0, OneOfType.Value.class); - assertEquals("int32", union.getCaseType().toString()); + assertEquals("int32", oneOf.getCaseEnumType().toString(union.getCaseType())); assertEquals(42, (int) union.getValue()); } @@ -83,7 +83,7 @@ public void testNanosInstant() { Schema schema = Schema.builder().addLogicalTypeField("now", new NanosInstant()).build(); Row row = Row.withSchema(schema).addValues(now).build(); assertEquals(now, row.getLogicalTypeValue(0, NanosInstant.class)); - assertEquals(nowAsRow, row.getValue(0)); + assertEquals(nowAsRow, row.getBaseValue(0, Row.class)); } @Test @@ -95,6 +95,6 @@ public void testNanosDuration() { Schema schema = Schema.builder().addLogicalTypeField("duration", new NanosDuration()).build(); Row row = Row.withSchema(schema).addValues(duration).build(); assertEquals(duration, row.getLogicalTypeValue(0, NanosDuration.class)); - assertEquals(durationAsRow, row.getValue(0)); + assertEquals(durationAsRow, row.getBaseValue(0, Row.class)); } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/GroupTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/GroupTest.java index e476a32b2cc3..57c030124792 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/GroupTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/GroupTest.java @@ -24,18 +24,20 @@ import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; +import com.google.auto.value.AutoValue; import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; import java.util.Iterator; import java.util.List; -import java.util.Objects; -import org.apache.beam.sdk.schemas.JavaFieldSchema; +import org.apache.beam.sdk.schemas.AutoValueSchema; import 
org.apache.beam.sdk.schemas.NoSuchSchemaException; import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.Schema.TypeName; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; +import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType; import org.apache.beam.sdk.schemas.utils.SchemaTestUtils.RowFieldMatcherIterableFieldAnyOrder; import org.apache.beam.sdk.testing.NeedsRunner; import org.apache.beam.sdk.testing.PAssert; @@ -44,6 +46,9 @@ import org.apache.beam.sdk.transforms.Combine.CombineFn; import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.Sample; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.Sum; import org.apache.beam.sdk.transforms.Top; @@ -66,56 +71,22 @@ public class GroupTest implements Serializable { @Rule public final transient TestPipeline pipeline = TestPipeline.create(); - /** A simple POJO for testing. */ - @DefaultSchema(JavaFieldSchema.class) - public static class POJO implements Serializable { - public String field1; - public long field2; - public String field3; - - public POJO(String field1, long field2, String field3) { - this.field1 = field1; - this.field2 = field2; - this.field3 = field3; - } - - public POJO() {} + /** A basic type for testing. 
*/ + @DefaultSchema(AutoValueSchema.class) + @AutoValue + abstract static class Basic implements Serializable { + abstract String getField1(); - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - POJO pojo = (POJO) o; - return field2 == pojo.field2 - && Objects.equals(field1, pojo.field1) - && Objects.equals(field3, pojo.field3); - } + abstract long getField2(); - @Override - public int hashCode() { - return Objects.hash(field1, field2, field3); - } + abstract String getField3(); - @Override - public String toString() { - return "POJO{" - + "field1='" - + field1 - + '\'' - + ", field2=" - + field2 - + ", field3='" - + field3 - + '\'' - + '}'; + static Basic of(String field1, long field2, String field3) { + return new AutoValue_GroupTest_Basic(field1, field2, field3); } } - private static final Schema POJO_SCHEMA = + private static final Schema BASIC_SCHEMA = Schema.builder() .addStringField("field1") .addInt64Field("field2") @@ -129,21 +100,21 @@ public void testGroupByOneField() throws NoSuchSchemaException { pipeline .apply( Create.of( - new POJO("key1", 1, "value1"), - new POJO("key1", 2, "value2"), - new POJO("key2", 3, "value3"), - new POJO("key2", 4, "value4"))) + Basic.of("key1", 1, "value1"), + Basic.of("key1", 2, "value2"), + Basic.of("key2", 3, "value3"), + Basic.of("key2", 4, "value4"))) .apply(Group.byFieldNames("field1")); Schema keySchema = Schema.builder().addStringField("field1").build(); Schema outputSchema = Schema.builder() .addRowField("key", keySchema) - .addIterableField("value", FieldType.row(POJO_SCHEMA)) + .addIterableField("value", FieldType.row(BASIC_SCHEMA)) .build(); - SerializableFunction toRow = - pipeline.getSchemaRegistry().getToRowFunction(POJO.class); + SerializableFunction toRow = + pipeline.getSchemaRegistry().getToRowFunction(Basic.class); List expected = ImmutableList.of( @@ -151,18 +122,18 @@ public void 
testGroupByOneField() throws NoSuchSchemaException { .addValue(Row.withSchema(keySchema).addValue("key1").build()) .addIterable( ImmutableList.of( - toRow.apply(new POJO("key1", 1L, "value1")), - toRow.apply(new POJO("key1", 2L, "value2")))) + toRow.apply(Basic.of("key1", 1L, "value1")), + toRow.apply(Basic.of("key1", 2L, "value2")))) .build(), Row.withSchema(outputSchema) .addValue(Row.withSchema(keySchema).addValue("key2").build()) .addIterable( ImmutableList.of( - toRow.apply(new POJO("key2", 3L, "value3")), - toRow.apply(new POJO("key2", 4L, "value4")))) + toRow.apply(Basic.of("key2", 3L, "value3")), + toRow.apply(Basic.of("key2", 4L, "value4")))) .build()); - PAssert.that(grouped).satisfies(actual -> containsKIterableVs(expected, actual, new POJO[0])); + PAssert.that(grouped).satisfies(actual -> containsKIterableVs(expected, actual, new Basic[0])); pipeline.run(); } @@ -173,20 +144,20 @@ public void testGroupByMultiple() throws NoSuchSchemaException { pipeline .apply( Create.of( - new POJO("key1", 1, "value1"), - new POJO("key1", 1, "value2"), - new POJO("key2", 2, "value3"), - new POJO("key2", 2, "value4"))) + Basic.of("key1", 1, "value1"), + Basic.of("key1", 1, "value2"), + Basic.of("key2", 2, "value3"), + Basic.of("key2", 2, "value4"))) .apply(Group.byFieldNames("field1", "field2")); Schema keySchema = Schema.builder().addStringField("field1").addInt64Field("field2").build(); Schema outputSchema = Schema.builder() .addRowField("key", keySchema) - .addIterableField("value", FieldType.row(POJO_SCHEMA)) + .addIterableField("value", FieldType.row(BASIC_SCHEMA)) .build(); - SerializableFunction toRow = - pipeline.getSchemaRegistry().getToRowFunction(POJO.class); + SerializableFunction toRow = + pipeline.getSchemaRegistry().getToRowFunction(Basic.class); List expected = ImmutableList.of( @@ -194,57 +165,34 @@ public void testGroupByMultiple() throws NoSuchSchemaException { .addValue(Row.withSchema(keySchema).addValues("key1", 1L).build()) .addIterable( 
ImmutableList.of( - toRow.apply(new POJO("key1", 1L, "value1")), - toRow.apply(new POJO("key1", 1L, "value2")))) + toRow.apply(Basic.of("key1", 1L, "value1")), + toRow.apply(Basic.of("key1", 1L, "value2")))) .build(), Row.withSchema(outputSchema) .addValue(Row.withSchema(keySchema).addValues("key2", 2L).build()) .addIterable( ImmutableList.of( - toRow.apply(new POJO("key2", 2L, "value3")), - toRow.apply(new POJO("key2", 2L, "value4")))) + toRow.apply(Basic.of("key2", 2L, "value3")), + toRow.apply(Basic.of("key2", 2L, "value4")))) .build()); - PAssert.that(grouped).satisfies(actual -> containsKIterableVs(expected, actual, new POJO[0])); + PAssert.that(grouped).satisfies(actual -> containsKIterableVs(expected, actual, new Basic[0])); pipeline.run(); } /** A class for testing nested key grouping. */ - @DefaultSchema(JavaFieldSchema.class) - public static class OuterPOJO implements Serializable { - public POJO inner; - - public OuterPOJO(POJO inner) { - this.inner = inner; - } - - public OuterPOJO() {} - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - OuterPOJO outerPOJO = (OuterPOJO) o; - return Objects.equals(inner, outerPOJO.inner); - } - - @Override - public int hashCode() { - return Objects.hash(inner); - } + @DefaultSchema(AutoValueSchema.class) + @AutoValue + abstract static class Outer implements Serializable { + abstract Basic getInner(); - @Override - public String toString() { - return "OuterPOJO{" + "inner=" + inner + '}'; + static Outer of(Basic inner) { + return new AutoValue_GroupTest_Outer(inner); } } - private static final Schema OUTER_POJO_SCHEMA = - Schema.builder().addRowField("inner", POJO_SCHEMA).build(); + private static final Schema OUTER_SCHEMA = + Schema.builder().addRowField("inner", BASIC_SCHEMA).build(); /** Test grouping by a set of fields that are nested. 
*/ @Test @@ -254,21 +202,21 @@ public void testGroupByNestedKey() throws NoSuchSchemaException { pipeline .apply( Create.of( - new OuterPOJO(new POJO("key1", 1L, "value1")), - new OuterPOJO(new POJO("key1", 1L, "value2")), - new OuterPOJO(new POJO("key2", 2L, "value3")), - new OuterPOJO(new POJO("key2", 2L, "value4")))) + Outer.of(Basic.of("key1", 1L, "value1")), + Outer.of(Basic.of("key1", 1L, "value2")), + Outer.of(Basic.of("key2", 2L, "value3")), + Outer.of(Basic.of("key2", 2L, "value4")))) .apply(Group.byFieldNames("inner.field1", "inner.field2")); Schema keySchema = Schema.builder().addStringField("field1").addInt64Field("field2").build(); Schema outputSchema = Schema.builder() .addRowField("key", keySchema) - .addIterableField("value", FieldType.row(OUTER_POJO_SCHEMA)) + .addIterableField("value", FieldType.row(OUTER_SCHEMA)) .build(); - SerializableFunction toRow = - pipeline.getSchemaRegistry().getToRowFunction(OuterPOJO.class); + SerializableFunction toRow = + pipeline.getSchemaRegistry().getToRowFunction(Outer.class); List expected = ImmutableList.of( @@ -276,33 +224,32 @@ public void testGroupByNestedKey() throws NoSuchSchemaException { .addValue(Row.withSchema(keySchema).addValues("key1", 1L).build()) .addIterable( ImmutableList.of( - toRow.apply(new OuterPOJO(new POJO("key1", 1L, "value1"))), - toRow.apply(new OuterPOJO(new POJO("key1", 1L, "value2"))))) + toRow.apply(Outer.of(Basic.of("key1", 1L, "value1"))), + toRow.apply(Outer.of(Basic.of("key1", 1L, "value2"))))) .build(), Row.withSchema(outputSchema) .addValue(Row.withSchema(keySchema).addValues("key2", 2L).build()) .addIterable( ImmutableList.of( - toRow.apply(new OuterPOJO(new POJO("key2", 2L, "value3"))), - toRow.apply(new OuterPOJO(new POJO("key2", 2L, "value4"))))) + toRow.apply(Outer.of(Basic.of("key2", 2L, "value3"))), + toRow.apply(Outer.of(Basic.of("key2", 2L, "value4"))))) .build()); - PAssert.that(grouped) - .satisfies(actual -> containsKIterableVs(expected, actual, new OuterPOJO[0])); 
+ PAssert.that(grouped).satisfies(actual -> containsKIterableVs(expected, actual, new Outer[0])); pipeline.run(); } @Test @Category(NeedsRunner.class) public void testGroupGlobally() { - Collection elements = + Collection elements = ImmutableList.of( - new POJO("key1", 1, "value1"), - new POJO("key1", 1, "value2"), - new POJO("key2", 2, "value3"), - new POJO("key2", 2, "value4")); + Basic.of("key1", 1, "value1"), + Basic.of("key1", 1, "value2"), + Basic.of("key2", 2, "value3"), + Basic.of("key2", 2, "value4")); - PCollection> grouped = + PCollection> grouped = pipeline.apply(Create.of(elements)).apply(Group.globally()); PAssert.that(grouped).satisfies(actual -> containsSingleIterable(elements, actual)); pipeline.run(); @@ -311,16 +258,16 @@ public void testGroupGlobally() { @Test @Category(NeedsRunner.class) public void testGlobalAggregation() { - Collection elements = + Collection elements = ImmutableList.of( - new POJO("key1", 1, "value1"), - new POJO("key1", 1, "value2"), - new POJO("key2", 2, "value3"), - new POJO("key2", 2, "value4")); + Basic.of("key1", 1, "value1"), + Basic.of("key1", 1, "value2"), + Basic.of("key2", 2, "value3"), + Basic.of("key2", 2, "value4")); PCollection count = pipeline .apply(Create.of(elements)) - .apply(Group.globally().aggregate(Count.combineFn())); + .apply(Group.globally().aggregate(Count.combineFn())); PAssert.that(count).containsInAnyOrder(4L); pipeline.run(); @@ -333,12 +280,12 @@ public void testOutputCoders() { Schema outputSchema = Schema.builder() .addRowField("key", keySchema) - .addIterableField("value", FieldType.row(POJO_SCHEMA)) + .addIterableField("value", FieldType.row(BASIC_SCHEMA)) .build(); PCollection grouped = pipeline - .apply(Create.of(new POJO("key1", 1, "value1"))) + .apply(Create.of(Basic.of("key1", 1, "value1"))) .apply(Group.byFieldNames("field1")); assertTrue(grouped.getSchema().equivalent(outputSchema)); @@ -347,53 +294,35 @@ public void testOutputCoders() { } /** A class for testing field aggregation. 
*/ - @DefaultSchema(JavaFieldSchema.class) - public static class AggregatePojos implements Serializable { - public long field1; - public long field2; - public int field3; - - public AggregatePojos(long field1, long field2, int field3) { - this.field1 = field1; - this.field2 = field2; - this.field3 = field3; - } + @DefaultSchema(AutoValueSchema.class) + @AutoValue + abstract static class Aggregate implements Serializable { + abstract long getField1(); - public AggregatePojos() {} + abstract long getField2(); - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - AggregatePojos agg = (AggregatePojos) o; - return field1 == agg.field1 && field2 == agg.field2 && field3 == agg.field3; - } + abstract int getField3(); - @Override - public int hashCode() { - return Objects.hash(field1, field2, field3); + static Aggregate of(long field1, long field2, int field3) { + return new AutoValue_GroupTest_Aggregate(field1, field2, field3); } } @Test @Category(NeedsRunner.class) public void testByKeyWithSchemaAggregateFn() { - Collection elements = + Collection elements = ImmutableList.of( - new AggregatePojos(1, 1, 2), - new AggregatePojos(2, 1, 3), - new AggregatePojos(3, 2, 4), - new AggregatePojos(4, 2, 5)); + Aggregate.of(1, 1, 2), + Aggregate.of(2, 1, 3), + Aggregate.of(3, 2, 4), + Aggregate.of(4, 2, 5)); PCollection aggregations = pipeline .apply(Create.of(elements)) .apply( - Group.byFieldNames("field2") + Group.byFieldNames("field2") .aggregateField("field1", Sum.ofLongs(), "field1_sum") .aggregateField("field3", Sum.ofIntegers(), "field3_sum") .aggregateField("field1", Top.largestLongsFn(1), "field1_top")); @@ -426,18 +355,18 @@ public void testByKeyWithSchemaAggregateFn() { @Test @Category(NeedsRunner.class) public void testGloballyWithSchemaAggregateFn() { - Collection elements = + Collection elements = ImmutableList.of( - new AggregatePojos(1, 1, 2), - new 
AggregatePojos(2, 1, 3), - new AggregatePojos(3, 2, 4), - new AggregatePojos(4, 2, 5)); + Aggregate.of(1, 1, 2), + Aggregate.of(2, 1, 3), + Aggregate.of(3, 2, 4), + Aggregate.of(4, 2, 5)); PCollection aggregate = pipeline .apply(Create.of(elements)) .apply( - Group.globally() + Group.globally() .aggregateField("field1", Sum.ofLongs(), "field1_sum") .aggregateField("field3", Sum.ofIntegers(), "field3_sum") .aggregateField("field1", Top.largestLongsFn(1), "field1_top")); @@ -495,19 +424,19 @@ public Long extractOutput(long[] accumulator) { @Test @Category(NeedsRunner.class) public void testAggregateByMultipleFields() { - Collection elements = + Collection elements = ImmutableList.of( - new AggregatePojos(1, 1, 2), - new AggregatePojos(2, 1, 3), - new AggregatePojos(3, 2, 4), - new AggregatePojos(4, 2, 5)); + Aggregate.of(1, 1, 2), + Aggregate.of(2, 1, 3), + Aggregate.of(3, 2, 4), + Aggregate.of(4, 2, 5)); List fieldNames = Lists.newArrayList("field1", "field2"); PCollection aggregate = pipeline .apply(Create.of(elements)) .apply( - Group.globally() + Group.globally() .aggregateFields(fieldNames, new MultipleFieldCombineFn(), "field1+field2")); Schema outputSchema = Schema.builder().addInt64Field("field1+field2").build(); @@ -518,31 +447,13 @@ public void testAggregateByMultipleFields() { } /** A class for testing nested aggregation. 
*/ - @DefaultSchema(JavaFieldSchema.class) - public static class OuterAggregate implements Serializable { - public AggregatePojos inner; + @DefaultSchema(AutoValueSchema.class) + @AutoValue + abstract static class OuterAggregate implements Serializable { + abstract Aggregate getInner(); - public OuterAggregate(AggregatePojos inner) { - this.inner = inner; - } - - public OuterAggregate() {} - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - OuterAggregate that = (OuterAggregate) o; - return Objects.equals(inner, that.inner); - } - - @Override - public int hashCode() { - return Objects.hash(inner); + static OuterAggregate of(Aggregate inner) { + return new AutoValue_GroupTest_OuterAggregate(inner); } } @@ -551,10 +462,10 @@ public int hashCode() { public void testByKeyWithSchemaAggregateFnNestedFields() { Collection elements = ImmutableList.of( - new OuterAggregate(new AggregatePojos(1, 1, 2)), - new OuterAggregate(new AggregatePojos(2, 1, 3)), - new OuterAggregate(new AggregatePojos(3, 2, 4)), - new OuterAggregate(new AggregatePojos(4, 2, 5))); + OuterAggregate.of(Aggregate.of(1, 1, 2)), + OuterAggregate.of(Aggregate.of(2, 1, 3)), + OuterAggregate.of(Aggregate.of(3, 2, 4)), + OuterAggregate.of(Aggregate.of(4, 2, 5))); PCollection aggregations = pipeline @@ -595,10 +506,10 @@ public void testByKeyWithSchemaAggregateFnNestedFields() { public void testGloballyWithSchemaAggregateFnNestedFields() { Collection elements = ImmutableList.of( - new OuterAggregate(new AggregatePojos(1, 1, 2)), - new OuterAggregate(new AggregatePojos(2, 1, 3)), - new OuterAggregate(new AggregatePojos(3, 2, 4)), - new OuterAggregate(new AggregatePojos(4, 2, 5))); + OuterAggregate.of(Aggregate.of(1, 1, 2)), + OuterAggregate.of(Aggregate.of(2, 1, 3)), + OuterAggregate.of(Aggregate.of(3, 2, 4)), + OuterAggregate.of(Aggregate.of(4, 2, 5))); PCollection aggregate = pipeline @@ -620,6 +531,89 @@ 
public void testGloballyWithSchemaAggregateFnNestedFields() { pipeline.run(); } + @DefaultSchema(AutoValueSchema.class) + @AutoValue + abstract static class BasicEnum { + enum Test { + ZERO, + ONE, + TWO + }; + + abstract String getKey(); + + abstract Test getEnumeration(); + + static BasicEnum of(String key, Test value) { + return new AutoValue_GroupTest_BasicEnum(key, value); + } + } + + static final EnumerationType BASIC_ENUM_ENUMERATION = + EnumerationType.create("ZERO", "ONE", "TWO"); + static final Schema BASIC_ENUM_SCHEMA = + Schema.builder() + .addStringField("key") + .addLogicalTypeField("enumeration", BASIC_ENUM_ENUMERATION) + .build(); + + @Test + @Category(NeedsRunner.class) + public void testAggregateBaseValuesGlobally() { + Collection elements = + Lists.newArrayList( + BasicEnum.of("a", BasicEnum.Test.ONE), BasicEnum.of("a", BasicEnum.Test.TWO)); + + PCollection aggregate = + pipeline + .apply(Create.of(elements)) + .apply( + Group.globally() + .aggregateFieldBaseValue("enumeration", Sum.ofIntegers(), "enum_sum")); + Schema aggregateSchema = Schema.builder().addInt32Field("enum_sum").build(); + Row expectedRow = Row.withSchema(aggregateSchema).addValues(3).build(); + PAssert.that(aggregate).containsInAnyOrder(expectedRow); + + pipeline.run(); + } + + @Test + @Category(NeedsRunner.class) + public void testAggregateLogicalValuesGlobally() { + Collection elements = + Lists.newArrayList( + BasicEnum.of("a", BasicEnum.Test.ONE), BasicEnum.of("a", BasicEnum.Test.TWO)); + + CombineFn> sampleAnyCombineFn = + Sample.anyCombineFn(100); + Field aggField = + Field.of("sampleList", FieldType.array(FieldType.logicalType(BASIC_ENUM_ENUMERATION))); + pipeline + .apply(Create.of(elements)) + .apply( + Group.globally().aggregateField("enumeration", sampleAnyCombineFn, aggField)) + .apply( + ParDo.of( + new DoFn>() { + @ProcessElement + // TODO: List doesn't get converted properly by ConvertHelpers, so the + // following line does + // not work. 
TO fix this we need to move logical-type conversion out of + // RowWithGetters and into + // the actual getters. + // public void process(@FieldAccess("sampleList") List values) + // { + public void process(@Element Row value) { + assertThat( + value.getArray(0), + containsInAnyOrder( + BASIC_ENUM_ENUMERATION.valueOf(1), BASIC_ENUM_ENUMERATION.valueOf(2))); + } + })); + + pipeline.run(); + } + private static Void containsKIterableVs( List expectedKvs, Iterable actualKvs, T[] emptyArray) { List list = Lists.newArrayList(actualKvs); @@ -673,8 +667,8 @@ public void describeTo(Description description) { } private static Void containsSingleIterable( - Collection expected, Iterable> actual) { - POJO[] values = expected.toArray(new POJO[0]); + Collection expected, Iterable> actual) { + Basic[] values = expected.toArray(new Basic[0]); assertThat(actual, containsInAnyOrder(containsInAnyOrder(values))); return null; } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/SchemaTestUtils.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/SchemaTestUtils.java index ca5cea5d683d..7b541d0aa4b6 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/SchemaTestUtils.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/SchemaTestUtils.java @@ -21,9 +21,16 @@ import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder; import static org.junit.Assert.assertTrue; +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; import java.util.List; +import java.util.Map; +import java.util.Objects; import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.Schema.TypeName; import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; @@ -38,6 +45,135 @@ public static void assertSchemaEquivalent(Schema expected, Schema actual) { assertTrue("Expected: " + expected + " Got: " + actual, actual.equivalent(expected)); } + public static class RowEquivalent extends BaseMatcher { + private final Row expected; + + public RowEquivalent(Row expected) { + this.expected = expected; + } + + @Override + public boolean matches(Object actual) { + if (actual == null) { + return expected == null; + } + if (!(actual instanceof Row)) { + return false; + } + Row actualRow = (Row) actual; + return rowsEquivalent(expected, actualRow); + } + + private static boolean rowsEquivalent(Row expected, Row actual) { + if (!actual.getSchema().equivalent(expected.getSchema())) { + return false; + } + if (expected.getFieldCount() != actual.getFieldCount()) { + return false; + } + for (int i = 0; i < expected.getFieldCount(); ++i) { + Field field = expected.getSchema().getField(i); + int actualIndex = actual.getSchema().indexOf(field.getName()); + if (!fieldsEquivalent( + expected.getValue(i), actual.getValue(actualIndex), field.getType())) { + return false; + } + } + return true; + } + + private static boolean fieldsEquivalent(Object expected, Object actual, FieldType fieldType) { + if (expected == null || actual == null) { + return expected == actual; + } else if (fieldType.getTypeName() == TypeName.LOGICAL_TYPE) { + return fieldsEquivalent(expected, actual, fieldType.getLogicalType().getBaseType()); + } else if (fieldType.getTypeName() == Schema.TypeName.BYTES) { + return Arrays.equals((byte[]) expected, (byte[]) actual); + } else if (fieldType.getTypeName() == TypeName.ARRAY) { + return collectionsEquivalent( + (Collection) expected, + (Collection) actual, + fieldType.getCollectionElementType()); + } else if (fieldType.getTypeName() == TypeName.ITERABLE) { + return iterablesEquivalent( + (Iterable) expected, + (Iterable) actual, + 
fieldType.getCollectionElementType()); + } else if (fieldType.getTypeName() == Schema.TypeName.MAP) { + return mapsEquivalent( + (Map) expected, + (Map) actual, + fieldType.getMapValueType()); + } else { + return Objects.equals(expected, actual); + } + } + + static boolean collectionsEquivalent( + Collection expected, Collection actual, Schema.FieldType elementType) { + if (expected == actual) { + return true; + } + + if (expected.size() != actual.size()) { + return false; + } + + return iterablesEquivalent(expected, actual, elementType); + } + + static boolean iterablesEquivalent( + Iterable expected, Iterable actual, Schema.FieldType elementType) { + if (expected == actual) { + return true; + } + Iterator actualIter = actual.iterator(); + for (Object currentExpected : expected) { + if (!actualIter.hasNext()) { + return false; + } + if (!fieldsEquivalent(currentExpected, actualIter.next(), elementType)) { + return false; + } + } + return !actualIter.hasNext(); + } + + static boolean mapsEquivalent( + Map expected, Map actual, Schema.FieldType valueType) { + if (expected == actual) { + return true; + } + + if (expected.size() != actual.size()) { + return false; + } + + for (Map.Entry expectedElement : expected.entrySet()) { + K key = expectedElement.getKey(); + V value = expectedElement.getValue(); + V otherValue = actual.get(key); + + if (value == null) { + if (otherValue != null || !actual.containsKey(key)) { + return false; + } + } else { + if (!fieldsEquivalent(value, otherValue, valueType)) { + return false; + } + } + } + + return true; + } + + @Override + public void describeTo(Description description) { + description.appendValue(expected); + } + } + public static class RowFieldMatcherIterableFieldAnyOrder extends BaseMatcher { private final int fieldIndex; private final Object expected; @@ -67,7 +203,7 @@ public boolean matches(Object item) { return false; } Row actualRow = row.getRow(fieldIndex); - return equalTo((Row) expected).matches(actualRow); + return 
new RowEquivalent((Row) expected).matches(actualRow); case ARRAY: Row[] expectedArray = ((List) expected).toArray(new Row[0]); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java index be779c0dae22..ce80d29926c8 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java @@ -872,10 +872,12 @@ public enum Color { }; public final Color color; + public final List colors; @SchemaCreate - public PojoWithEnum(Color color) { + public PojoWithEnum(Color color, List colors) { this.color = color; + this.colors = colors; } @Override @@ -887,19 +889,22 @@ public boolean equals(Object o) { return false; } PojoWithEnum that = (PojoWithEnum) o; - return color == that.color; + return color == that.color && Objects.equals(colors, that.colors); } @Override public int hashCode() { - return Objects.hash(color); + return Objects.hash(color, colors); } } /** The schema for {@link PojoWithEnum}. */ + public static final EnumerationType ENUMERATION = EnumerationType.create("RED", "GREEN", "BLUE"); + public static final Schema POJO_WITH_ENUM_SCHEMA = Schema.builder() - .addLogicalTypeField("color", EnumerationType.create("RED", "GREEN", "BLUE")) + .addLogicalTypeField("color", ENUMERATION) + .addArrayField("colors", FieldType.logicalType(ENUMERATION)) .build(); /** A simple POJO containing nullable basic types. 
* */ diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchema.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchema.java index ec8537cd6242..31856583058f 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchema.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchema.java @@ -441,10 +441,10 @@ void setOnProtoMessage(Message.Builder message, Object value) { @Override Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { - Row row = (Row) value; + Instant ts = (Instant) value; return com.google.protobuf.Timestamp.newBuilder() - .setSeconds(row.getInt64(0)) - .setNanos(row.getInt32(1)) + .setSeconds(ts.getEpochSecond()) + .setNanos(ts.getNano()) .build(); } } @@ -486,10 +486,10 @@ void setOnProtoMessage(Message.Builder message, Object value) { @Override Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { - Row row = (Row) value; + Duration duration = (Duration) value; return com.google.protobuf.Duration.newBuilder() - .setSeconds(row.getInt64(0)) - .setNanos(row.getInt32(1)) + .setSeconds(duration.getSeconds()) + .setNanos(duration.getNano()) .build(); } } @@ -689,12 +689,12 @@ void setOnProtoMessage(Message.Builder message, Object value) { @Override Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { Descriptors.EnumDescriptor enumType = fieldDescriptor.getEnumType(); - return enumType.findValueByNumber((Integer) value); + return enumType.findValueByNumber(((EnumerationType.Value) value).getValue()); } } /** Convert Proto oneOf fields into the {@link OneOfType} logical type. 
*/ - static class OneOfConvert extends Convert { + static class OneOfConvert extends Convert { OneOfType logicalType; Map oneOfConvert = new HashMap<>(); @@ -726,8 +726,7 @@ OneOfType.Value convertFromProtoValue(Object in) { } @Override - void setOnProtoMessage(Message.Builder message, Row value) { - OneOfType.Value oneOf = logicalType.toInputType(value); + void setOnProtoMessage(Message.Builder message, OneOfType.Value oneOf) { int caseIndex = oneOf.getCaseType().getValue(); oneOfConvert.get(caseIndex).setOnProtoMessage(message, oneOf.getValue()); } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamAggregationRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamAggregationRel.java index 15712e0e9ff5..1be91e9db334 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamAggregationRel.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamAggregationRel.java @@ -256,8 +256,10 @@ public PCollection expand(PCollectionList pinput) { // Combining over a single field, so extract just that field. combined = (combined == null) - ? byFields.aggregateField(inputs.get(0), combineFn, fieldAggregation.outputField) - : combined.aggregateField(inputs.get(0), combineFn, fieldAggregation.outputField); + ? byFields.aggregateFieldBaseValue( + inputs.get(0), combineFn, fieldAggregation.outputField) + : combined.aggregateFieldBaseValue( + inputs.get(0), combineFn, fieldAggregation.outputField); } } @@ -327,13 +329,14 @@ static DoFn mergeRecord( @ProcessElement public void processElement( @Element Row kvRow, BoundedWindow window, OutputReceiver o) { - List fieldValues = - Lists.newArrayListWithCapacity( - kvRow.getRow(0).getValues().size() + kvRow.getRow(1).getValues().size()); + int capacity = + kvRow.getRow(0).getFieldCount() + + (!ignoreValues ? 
kvRow.getRow(1).getFieldCount() : 0); + List fieldValues = Lists.newArrayListWithCapacity(capacity); - fieldValues.addAll(kvRow.getRow(0).getValues()); + fieldValues.addAll(kvRow.getRow(0).getBaseValues()); if (!ignoreValues) { - fieldValues.addAll(kvRow.getRow(1).getValues()); + fieldValues.addAll(kvRow.getRow(1).getBaseValues()); } if (windowStartFieldIndex != -1) { diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java index 9434f7819815..77cd3887053d 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java @@ -29,6 +29,7 @@ import java.math.BigDecimal; import java.util.AbstractList; import java.util.AbstractMap; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Set; @@ -333,32 +334,32 @@ private static Expression castOutputTime(Expression value, FieldType toType) { } private static class InputGetterImpl implements RexToLixTranslator.InputGetter { - private static final Map TYPE_GETTER_MAP = - ImmutableMap.builder() - .put(TypeName.BYTE, "getByte") - .put(TypeName.BYTES, "getBytes") - .put(TypeName.INT16, "getInt16") - .put(TypeName.INT32, "getInt32") - .put(TypeName.INT64, "getInt64") - .put(TypeName.DECIMAL, "getDecimal") - .put(TypeName.FLOAT, "getFloat") - .put(TypeName.DOUBLE, "getDouble") - .put(TypeName.STRING, "getString") - .put(TypeName.DATETIME, "getDateTime") - .put(TypeName.BOOLEAN, "getBoolean") - .put(TypeName.MAP, "getMap") - .put(TypeName.ARRAY, "getArray") - .put(TypeName.ITERABLE, "getIterable") - .put(TypeName.ROW, "getRow") + private static final Map TYPE_CONVERSION_MAP = + ImmutableMap.builder() + .put(TypeName.BYTE, Byte.class) + .put(TypeName.BYTES, byte[].class) + 
.put(TypeName.INT16, Short.class) + .put(TypeName.INT32, Integer.class) + .put(TypeName.INT64, Long.class) + .put(TypeName.DECIMAL, BigDecimal.class) + .put(TypeName.FLOAT, Float.class) + .put(TypeName.DOUBLE, Double.class) + .put(TypeName.STRING, String.class) + .put(TypeName.DATETIME, ReadableInstant.class) + .put(TypeName.BOOLEAN, Boolean.class) + .put(TypeName.MAP, Map.class) + .put(TypeName.ARRAY, Collection.class) + .put(TypeName.ITERABLE, Iterable.class) + .put(TypeName.ROW, Row.class) .build(); - private static final Map LOGICAL_TYPE_GETTER_MAP = - ImmutableMap.builder() - .put(DateType.IDENTIFIER, "getDateTime") - .put(TimeType.IDENTIFIER, "getDateTime") - .put(TimeWithLocalTzType.IDENTIFIER, "getDateTime") - .put(TimestampWithLocalTzType.IDENTIFIER, "getDateTime") - .put(CharType.IDENTIFIER, "getString") + private static final Map LOGICAL_TYPE_CONVERSION_MAP = + ImmutableMap.builder() + .put(DateType.IDENTIFIER, ReadableInstant.class) + .put(TimeType.IDENTIFIER, ReadableInstant.class) + .put(TimeWithLocalTzType.IDENTIFIER, ReadableInstant.class) + .put(TimestampWithLocalTzType.IDENTIFIER, ReadableInstant.class) + .put(CharType.IDENTIFIER, String.class) .build(); private final Expression input; @@ -381,24 +382,29 @@ private static Expression value( } final Expression expression = list.append(list.newName("current"), input); - if (storageType == Object.class) { - return Expressions.convert_( - Expressions.call(expression, "getValue", Expressions.constant(index)), Object.class); - } + FieldType fromType = schema.getField(index).getType(); - String getter; - if (fromType.getTypeName().isLogicalType()) { - getter = LOGICAL_TYPE_GETTER_MAP.get(fromType.getLogicalType().getIdentifier()); + Class convertTo = null; + if (storageType == Object.class) { + convertTo = Object.class; + } else if (fromType.getTypeName().isLogicalType()) { + convertTo = LOGICAL_TYPE_CONVERSION_MAP.get(fromType.getLogicalType().getIdentifier()); } else { - getter = 
TYPE_GETTER_MAP.get(fromType.getTypeName()); + convertTo = TYPE_CONVERSION_MAP.get(fromType.getTypeName()); } - if (getter == null) { + if (convertTo == null) { throw new UnsupportedOperationException("Unable to get " + fromType.getTypeName()); } - Expression value = Expressions.call(expression, getter, Expressions.constant(index)); - - return value(value, fromType); + Expression value = + Expressions.convert_( + Expressions.call( + expression, + "getBaseValue", + Expressions.constant(index), + Expressions.constant(convertTo)), + convertTo); + return (storageType != Object.class) ? value(value, fromType) : value; } private static Expression value(Expression value, Schema.FieldType type) { diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamEnumerableConverter.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamEnumerableConverter.java index 88032fbae389..db32362e47ed 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamEnumerableConverter.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamEnumerableConverter.java @@ -287,7 +287,7 @@ private static Object[] rowToAvatica(Row row) { Object[] convertedColumns = new Object[schema.getFields().size()]; int i = 0; for (Schema.Field field : schema.getFields()) { - convertedColumns[i] = fieldToAvatica(field.getType(), row.getValue(i)); + convertedColumns[i] = fieldToAvatica(field.getType(), row.getBaseValue(i, Object.class)); ++i; } return convertedColumns; diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRel.java index afe18b044d28..ec6582083234 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRel.java +++ 
b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRel.java @@ -345,8 +345,8 @@ public int compare(Row row1, Row row2) { case VARCHAR: case DATE: case TIMESTAMP: - Comparable v1 = (Comparable) row1.getValue(fieldIndex); - Comparable v2 = (Comparable) row2.getValue(fieldIndex); + Comparable v1 = row1.getBaseValue(fieldIndex, Comparable.class); + Comparable v2 = row2.getBaseValue(fieldIndex, Comparable.class); fieldRet = v1.compareTo(v2); break; default: diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUncollectRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUncollectRel.java index 2b2511dac90d..24545b616698 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUncollectRel.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUncollectRel.java @@ -103,7 +103,7 @@ public void process(@Element Row inputRow, OutputReceiver output) { for (Object element : inputRow.getArray(0)) { if (element instanceof Row) { Row nestedRow = (Row) element; - output.output(Row.withSchema(schema).addValues(nestedRow.getValues()).build()); + output.output(Row.withSchema(schema).addValues(nestedRow.getBaseValues()).build()); } else { output.output(Row.withSchema(schema).addValue(element).build()); } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java index 9f27a4ad9375..db38ce736a96 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java @@ -142,13 +142,13 @@ public void process(@Element Row row, OutputReceiver out) { Row nestedRow = 
(Row) uncollectedValue; out.output( Row.withSchema(outputSchema) - .addValues(row.getValues()) - .addValues(nestedRow.getValues()) + .addValues(row.getBaseValues()) + .addValues(nestedRow.getBaseValues()) .build()); } else { out.output( Row.withSchema(outputSchema) - .addValues(row.getValues()) + .addValues(row.getBaseValues()) .addValue(uncollectedValue) .build()); } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/schema/BeamTableUtils.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/schema/BeamTableUtils.java index 29d82d3930af..f3c2704cdbc4 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/schema/BeamTableUtils.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/schema/BeamTableUtils.java @@ -84,7 +84,7 @@ public static String beamRow2CsvLine(Row row, CSVFormat csvFormat) { StringWriter writer = new StringWriter(); try (CSVPrinter printer = csvFormat.print(writer)) { for (int i = 0; i < row.getFieldCount(); i++) { - printer.print(row.getValue(i).toString()); + printer.print(row.getBaseValue(i, Object.class).toString()); } printer.println(); } catch (IOException e) { diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamJoinTransforms.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamJoinTransforms.java index cf8ffd258bb1..0aa74a2277d2 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamJoinTransforms.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamJoinTransforms.java @@ -86,8 +86,8 @@ private static Row combineTwoRowsIntoOne( /** As the method name suggests: combine two rows into one wide row. 
*/ private static Row combineTwoRowsIntoOneHelper(Row leftRow, Row rightRow, Schema ouputSchema) { return Row.withSchema(ouputSchema) - .addValues(leftRow.getValues()) - .addValues(rightRow.getValues()) + .addValues(leftRow.getBaseValues()) + .addValues(rightRow.getBaseValues()) .build(); } @@ -170,7 +170,9 @@ public void teardown() { private Row extractJoinSubRow(Row factRow) { List joinSubsetValues = - factJoinIdx.stream().map(factRow::getValue).collect(toList()); + factJoinIdx.stream() + .map(i -> factRow.getBaseValue(i, Object.class)) + .collect(toList()); return Row.withSchema(joinSubsetType).addValues(joinSubsetValues).build(); } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CovarianceFn.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CovarianceFn.java index 825aad8a6a3f..ac7edfec1ae0 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CovarianceFn.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/CovarianceFn.java @@ -94,8 +94,8 @@ public CovarianceAccumulator addInput(CovarianceAccumulator currentVariance, Row return currentVariance.combineWith( CovarianceAccumulator.ofSingleElement( - SqlFunctions.toBigDecimal((Object) rawInput.getValue(0)), - SqlFunctions.toBigDecimal((Object) rawInput.getValue(1)))); + SqlFunctions.toBigDecimal((Object) rawInput.getBaseValue(0, Object.class)), + SqlFunctions.toBigDecimal((Object) rawInput.getBaseValue(1, Object.class)))); } @Override diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamComplexTypeTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamComplexTypeTest.java index 103de8e29815..ce46c17c29e9 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamComplexTypeTest.java +++ 
b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamComplexTypeTest.java @@ -396,8 +396,7 @@ public void testLogicalTypes() { PCollection outputRow = pipeline - .apply(Create.of(row)) - .setRowSchema(outputRowSchema) + .apply(Create.of(row).withRowSchema(inputRowSchema)) .apply( SqlTransform.query( "SELECT timeTypeField, dateTypeField FROM PCOLLECTION GROUP BY timeTypeField, dateTypeField")); diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/schema/BeamSqlRowCoderTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/schema/BeamSqlRowCoderTest.java index 3624d12708ed..f83ad7cf7ff1 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/schema/BeamSqlRowCoderTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/schema/BeamSqlRowCoderTest.java @@ -65,8 +65,8 @@ public void encodeAndDecode() throws Exception { 1.1, BigDecimal.ZERO, "hello", - DateTime.now(), - DateTime.now(), + DateTime.now().toInstant(), + DateTime.now().toInstant(), true) .build(); Coder coder = SchemaCoder.of(beamSchema); diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRel.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRel.java index 94a39f23caa6..0d6d25baa9bc 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRel.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRel.java @@ -174,7 +174,7 @@ public void processElement(ProcessContext c) { columns.put( columnName(i), ZetaSqlUtils.javaObjectToZetaSqlValue( - row.getValue(i), inputSchema.getField(i).getType())); + row.getBaseValue(i, Object.class), inputSchema.getField(i).getType())); } // TODO[BEAM-8630]: 
support parameters in expression evaluation diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlUtils.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlUtils.java index e60e330c2b46..2ee1691312ae 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlUtils.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlUtils.java @@ -181,7 +181,9 @@ public static Value beamRowToZetaSqlStructValue(Row row, Schema schema) { List values = new ArrayList<>(row.getFieldCount()); for (int i = 0; i < row.getFieldCount(); i++) { - values.add(javaObjectToZetaSqlValue(row.getValue(i), schema.getField(i).getType())); + values.add( + javaObjectToZetaSqlValue( + row.getBaseValue(i, Object.class), schema.getField(i).getType())); } return Value.createStructValue(createZetaSqlStructTypeFromBeamSchema(schema), values); }