diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueSetter.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueSetter.java index 5d9e82bf24a6..db7caaa0535d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueSetter.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueSetter.java @@ -26,7 +26,7 @@ * *

An interface to set a field of a class. * - *

Implementations of this interface are generated at runtime to map Row fields back into objet + *

Implementations of this interface are generated at runtime to map Row fields back into object * fields. */ @Internal diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java index a6ecc4500bd8..33ed888cf6cc 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java @@ -22,7 +22,10 @@ import java.lang.reflect.Field; import java.lang.reflect.Method; import java.util.Arrays; +import java.util.Collections; +import java.util.Map; import javax.annotation.Nullable; +import org.apache.beam.sdk.schemas.logicaltypes.OneOfType; import org.apache.beam.sdk.schemas.utils.ReflectUtils; import org.apache.beam.sdk.values.TypeDescriptor; @@ -47,6 +50,8 @@ public abstract class FieldValueTypeInformation implements Serializable { @Nullable public abstract Method getMethod(); + public abstract Map getOneOfTypes(); + /** If the field is a container type, returns the element type. */ @Nullable public abstract FieldValueTypeInformation getElementType(); @@ -62,7 +67,7 @@ public abstract class FieldValueTypeInformation implements Serializable { abstract Builder toBuilder(); @AutoValue.Builder - abstract static class Builder { + public abstract static class Builder { public abstract Builder setName(String name); public abstract Builder setNullable(boolean nullable); @@ -75,6 +80,8 @@ abstract static class Builder { public abstract Builder setMethod(@Nullable Method method); + public abstract Builder setOneOfTypes(Map oneOfTypes); + public abstract Builder setElementType(@Nullable FieldValueTypeInformation elementType); public abstract Builder setMapKeyType(@Nullable FieldValueTypeInformation mapKeyType); @@ -84,6 +91,22 @@ abstract static class Builder { abstract FieldValueTypeInformation build(); } + public static FieldValueTypeInformation forOneOf( + String name, boolean nullable, Map oneOfTypes) { + final TypeDescriptor typeDescriptor = TypeDescriptor.of(OneOfType.Value.class); + return new AutoValue_FieldValueTypeInformation.Builder() + .setName(name) + .setNullable(nullable) + .setType(typeDescriptor) + .setRawType(typeDescriptor.getRawType()) + .setField(null) + .setElementType(null) + .setMapKeyType(null) + .setMapValueType(null) + .setOneOfTypes(oneOfTypes) + .build(); + } + public static FieldValueTypeInformation forField(Field field) { TypeDescriptor type = TypeDescriptor.of(field.getGenericType()); return new AutoValue_FieldValueTypeInformation.Builder() @@ -95,6 +118,7 @@ public static FieldValueTypeInformation forField(Field field) { .setElementType(getIterableComponentType(field)) .setMapKeyType(getMapKeyType(field)) .setMapValueType(getMapValueType(field)) + .setOneOfTypes(Collections.emptyMap()) .build(); } @@ -119,6 +143,7 @@ public static FieldValueTypeInformation forGetter(Method method) { .setElementType(getIterableComponentType(type)) .setMapKeyType(getMapKeyType(type)) .setMapValueType(getMapValueType(type)) + .setOneOfTypes(Collections.emptyMap()) .build(); } @@ -148,6 +173,7 @@ public static FieldValueTypeInformation forSetter(Method method, String setterPr .setElementType(getIterableComponentType(type)) .setMapKeyType(getMapKeyType(type)) .setMapValueType(getMapValueType(type)) + .setOneOfTypes(Collections.emptyMap()) .build(); } @@ -175,6 +201,7 @@ static FieldValueTypeInformation getIterableComponentType(TypeDescriptor valueTy .setElementType(getIterableComponentType(componentType)) .setMapKeyType(getMapKeyType(componentType)) .setMapValueType(getMapValueType(componentType)) + .setOneOfTypes(Collections.emptyMap()) .build(); } @@ -217,6 +244,7 @@ private static FieldValueTypeInformation getMapType(TypeDescriptor valueType, in .setElementType(getIterableComponentType(mapType)) .setMapKeyType(getMapKeyType(mapType)) .setMapValueType(getMapValueType(mapType)) + .setOneOfTypes(Collections.emptyMap()) .build(); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java index 61c0d0520d3c..2d4bdfc7a655 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java @@ -28,6 +28,7 @@ import javax.annotation.Nullable; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.Schema.TypeName; +import org.apache.beam.sdk.schemas.logicaltypes.OneOfType; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.RowWithGetters; @@ -80,13 +81,7 @@ public ValueT fromRow( FieldValueTypeInformation typeInformation = checkNotNull(typeInformations.get(i)); params[i] = fromValue( - type, - row.getValue(i), - typeInformation.getRawType(), - typeInformation.getElementType(), - typeInformation.getMapKeyType(), - typeInformation.getMapValueType(), - typeFactory); + type, row.getValue(i), typeInformation.getRawType(), typeInformation, typeFactory); } SchemaUserTypeCreator creator = schemaTypeCreatorFactory.create(clazz, schema); @@ -99,10 +94,11 @@ private ValueT fromValue( FieldType type, ValueT value, Type fieldType, - FieldValueTypeInformation elementType, - FieldValueTypeInformation keyType, - FieldValueTypeInformation valueType, + FieldValueTypeInformation fieldValueTypeInformation, Factory> typeFactory) { + FieldValueTypeInformation elementType = fieldValueTypeInformation.getElementType(); + FieldValueTypeInformation keyType = fieldValueTypeInformation.getMapKeyType(); + FieldValueTypeInformation valueType = fieldValueTypeInformation.getMapValueType(); if (value == null) { return null; } @@ -127,6 +123,24 @@ private ValueT fromValue( valueType, typeFactory); } else { + if (type.isLogicalType(OneOfType.IDENTIFIER)) { + // If this is a OneOf union type, we must extract the current union value and convert that + // value to the Java + // type expected by the creator object. + OneOfType oneOfType = type.getLogicalType(OneOfType.class); + OneOfType.Value oneOfValue = oneOfType.toInputType((Row) value); + FieldValueTypeInformation oneOfFieldValueTypeInformation = + checkNotNull( + fieldValueTypeInformation.getOneOfTypes().get(oneOfValue.getCaseType().toString())); + Object fromValue = + fromValue( + oneOfValue.getFieldType(), + oneOfValue.getValue(), + oneOfFieldValueTypeInformation.getRawType(), + oneOfFieldValueTypeInformation, + typeFactory); + return (ValueT) oneOfType.createValue(oneOfValue.getCaseType(), fromValue); + } return value; } } @@ -156,9 +170,7 @@ private Collection fromCollectionValue( elementType, element, elementTypeInformation.getType().getType(), - elementTypeInformation.getElementType(), - elementTypeInformation.getMapKeyType(), - elementTypeInformation.getMapValueType(), + elementTypeInformation, typeFactory)); } @@ -175,9 +187,7 @@ private Iterable fromIterableValue( elementType, element, elementTypeInformation.getType().getType(), - elementTypeInformation.getElementType(), - elementTypeInformation.getMapKeyType(), - elementTypeInformation.getMapValueType(), + elementTypeInformation, typeFactory)); } @@ -196,18 +206,14 @@ private Iterable fromIterableValue( keyType, entry.getKey(), keyTypeInformation.getType().getType(), - keyTypeInformation.getElementType(), - keyTypeInformation.getMapKeyType(), - keyTypeInformation.getMapValueType(), + keyTypeInformation, typeFactory); Object value = fromValue( valueType, entry.getValue(), valueTypeInformation.getType().getType(), - valueTypeInformation.getElementType(), - valueTypeInformation.getMapKeyType(), - valueTypeInformation.getMapValueType(), + valueTypeInformation, typeFactory); newMap.put(key, value); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Convert.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Convert.java index 5176def497e2..5233b0c1e535 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Convert.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Convert.java @@ -40,7 +40,7 @@ public class Convert { * Convert a {@link PCollection}{@literal } into a {@link PCollection}{@literal }. * *

The input {@link PCollection} must have a schema attached. The output collection will have - * the same schema as the iput. + * the same schema as the input. */ public static PTransform, PCollection> toRows() { return to(Row.class); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AvroByteBuddyUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AvroByteBuddyUtils.java index fd7f6013be85..436da6cf4d4d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AvroByteBuddyUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AvroByteBuddyUtils.java @@ -54,7 +54,7 @@ class AvroByteBuddyUtils { static SchemaUserTypeCreator getCreator( Class clazz, Schema schema) { return CACHED_CREATORS.computeIfAbsent( - new ClassWithSchema(clazz, schema), c -> createCreator(clazz, schema)); + ClassWithSchema.create(clazz, schema), c -> createCreator(clazz, schema)); } private static SchemaUserTypeCreator createCreator(Class clazz, Schema schema) { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java index 791dafbb40f0..c00d5d093fdf 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ByteBuddyUtils.java @@ -94,7 +94,6 @@ public class ByteBuddyUtils { new ForLoadedType(ReadableInstant.class); private static final ForLoadedType READABLE_PARTIAL_TYPE = new ForLoadedType(ReadablePartial.class); - private static final ForLoadedType OBJECT_TYPE = new ForLoadedType(Object.class); private static final ForLoadedType INTEGER_TYPE = new ForLoadedType(Integer.class); private static final ForLoadedType ENUM_TYPE = new ForLoadedType(Enum.class); private static final ForLoadedType BYTE_BUDDY_UTILS_TYPE = @@ -134,7 +133,7 @@ protected String name(TypeDescription superClass) { // Create a new FieldValueGetter subclass. @SuppressWarnings("unchecked") - static DynamicType.Builder subclassGetterInterface( + public static DynamicType.Builder subclassGetterInterface( ByteBuddy byteBuddy, Type objectType, Type fieldType) { TypeDescription.Generic getterGenericType = TypeDescription.Generic.Builder.parameterizedType( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java index 759d77d25c3d..e25342bfc5e8 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/JavaBeanUtils.java @@ -102,7 +102,7 @@ public static void validateJavaBean( public static List getFieldTypes( Class clazz, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier) { return CACHED_FIELD_TYPES.computeIfAbsent( - new ClassWithSchema(clazz, schema), c -> fieldValueTypeSupplier.get(clazz, schema)); + ClassWithSchema.create(clazz, schema), c -> fieldValueTypeSupplier.get(clazz, schema)); } // The list of getters for a class is cached, so we only create the classes the first time @@ -121,7 +121,7 @@ public static List getGetters( FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { return CACHED_GETTERS.computeIfAbsent( - new ClassWithSchema(clazz, schema), + ClassWithSchema.create(clazz, schema), c -> { List types = fieldValueTypeSupplier.get(clazz, schema); return types.stream() @@ -130,7 +130,7 @@ public static List getGetters( }); } - private static FieldValueGetter createGetter( + public static FieldValueGetter createGetter( FieldValueTypeInformation typeInformation, TypeConversionsFactory typeConversionsFactory) { DynamicType.Builder builder = ByteBuddyUtils.subclassGetterInterface( @@ -184,7 +184,7 @@ public static List getSetters( FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { return CACHED_SETTERS.computeIfAbsent( - new ClassWithSchema(clazz, schema), + ClassWithSchema.create(clazz, schema), c -> { List types = fieldValueTypeSupplier.get(clazz, schema); return types.stream() @@ -193,14 +193,14 @@ public static List getSetters( }); } - private static FieldValueSetter createSetter( + public static FieldValueSetter createSetter( FieldValueTypeInformation typeInformation, TypeConversionsFactory typeConversionsFactory) { DynamicType.Builder builder = ByteBuddyUtils.subclassSetterInterface( BYTE_BUDDY, typeInformation.getMethod().getDeclaringClass(), typeConversionsFactory.createTypeConversion(false).convert(typeInformation.getType())); - builder = implementSetterMethods(builder, typeInformation.getMethod(), typeConversionsFactory); + builder = implementSetterMethods(builder, typeInformation, typeConversionsFactory); try { return builder .make() @@ -222,14 +222,13 @@ private static FieldValueSetter createSetter( private static DynamicType.Builder implementSetterMethods( DynamicType.Builder builder, - Method method, + FieldValueTypeInformation fieldValueTypeInformation, TypeConversionsFactory typeConversionsFactory) { - FieldValueTypeInformation javaTypeInformation = FieldValueTypeInformation.forSetter(method); return builder .method(ElementMatchers.named("name")) - .intercept(FixedValue.reference(javaTypeInformation.getName())) + .intercept(FixedValue.reference(fieldValueTypeInformation.getName())) .method(ElementMatchers.named("set")) - .intercept(new InvokeSetterInstruction(method, typeConversionsFactory)); + .intercept(new InvokeSetterInstruction(fieldValueTypeInformation, typeConversionsFactory)); } // The list of constructors for a class is cached, so we only create the classes the first time @@ -244,7 +243,7 @@ public static SchemaUserTypeCreator getConstructorCreator( FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { return CACHED_CREATORS.computeIfAbsent( - new ClassWithSchema(clazz, schema), + ClassWithSchema.create(clazz, schema), c -> { List types = fieldValueTypeSupplier.get(clazz, schema); return createConstructorCreator( @@ -291,7 +290,7 @@ public static SchemaUserTypeCreator getStaticCreator( FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { return CACHED_CREATORS.computeIfAbsent( - new ClassWithSchema(clazz, schema), + ClassWithSchema.create(clazz, schema), c -> { List types = fieldValueTypeSupplier.get(clazz, schema); return createStaticCreator(clazz, creator, schema, types, typeConversionsFactory); @@ -377,11 +376,13 @@ public ByteCodeAppender appender(final Target implementationTarget) { // Implements a method to write a public set out on an object. private static class InvokeSetterInstruction implements Implementation { // Setter method that will be invoked - private Method method; + private FieldValueTypeInformation fieldValueTypeInformation; private final TypeConversionsFactory typeConversionsFactory; - InvokeSetterInstruction(Method method, TypeConversionsFactory typeConversionsFactory) { - this.method = method; + InvokeSetterInstruction( + FieldValueTypeInformation fieldValueTypeInformation, + TypeConversionsFactory typeConversionsFactory) { + this.fieldValueTypeInformation = fieldValueTypeInformation; this.typeConversionsFactory = typeConversionsFactory; } @@ -393,13 +394,13 @@ public InstrumentedType prepare(InstrumentedType instrumentedType) { @Override public ByteCodeAppender appender(final Target implementationTarget) { return (methodVisitor, implementationContext, instrumentedMethod) -> { - FieldValueTypeInformation javaTypeInformation = FieldValueTypeInformation.forSetter(method); // this + method parameters. int numLocals = 1 + instrumentedMethod.getParameters().size(); // The instruction to read the field. StackManipulation readField = MethodVariableAccess.REFERENCE.loadFrom(2); + Method method = fieldValueTypeInformation.getMethod(); boolean setterMethodReturnsVoid = method.getReturnType().equals(Void.TYPE); // Read the object onto the stack. StackManipulation stackManipulation = @@ -409,7 +410,7 @@ public ByteCodeAppender appender(final Target implementationTarget) { // Do any conversions necessary. typeConversionsFactory .createSetterConversions(readField) - .convert(javaTypeInformation.getType()), + .convert(fieldValueTypeInformation.getType()), // Now update the field and return void. MethodInvocation.invoke(new ForLoadedMethod(method))); if (!setterMethodReturnsVoid) { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java index a58ddf8b8a40..aa968b415279 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/POJOUtils.java @@ -81,7 +81,7 @@ public static Schema schemaFromPojoClass( public static List getFieldTypes( Class clazz, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier) { return CACHED_FIELD_TYPES.computeIfAbsent( - new ClassWithSchema(clazz, schema), c -> fieldValueTypeSupplier.get(clazz, schema)); + ClassWithSchema.create(clazz, schema), c -> fieldValueTypeSupplier.get(clazz, schema)); } // The list of getters for a class is cached, so we only create the classes the first time @@ -96,7 +96,7 @@ public static List getGetters( TypeConversionsFactory typeConversionsFactory) { // Return the getters ordered by their position in the schema. return CACHED_GETTERS.computeIfAbsent( - new ClassWithSchema(clazz, schema), + ClassWithSchema.create(clazz, schema), c -> { List types = fieldValueTypeSupplier.get(clazz, schema); List getters = @@ -122,7 +122,7 @@ public static SchemaUserTypeCreator getSetFieldCreator( FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { return CACHED_CREATORS.computeIfAbsent( - new ClassWithSchema(clazz, schema), + ClassWithSchema.create(clazz, schema), c -> { List types = fieldValueTypeSupplier.get(clazz, schema); return createSetFieldCreator(clazz, schema, types, typeConversionsFactory); @@ -169,7 +169,7 @@ public static SchemaUserTypeCreator getConstructorCreator( FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { return CACHED_CREATORS.computeIfAbsent( - new ClassWithSchema(clazz, schema), + ClassWithSchema.create(clazz, schema), c -> { List types = fieldValueTypeSupplier.get(clazz, schema); return createConstructorCreator( @@ -217,7 +217,7 @@ public static SchemaUserTypeCreator getStaticCreator( FieldValueTypeSupplier fieldValueTypeSupplier, TypeConversionsFactory typeConversionsFactory) { return CACHED_CREATORS.computeIfAbsent( - new ClassWithSchema(clazz, schema), + ClassWithSchema.create(clazz, schema), c -> { List types = fieldValueTypeSupplier.get(clazz, schema); return createStaticCreator(clazz, creator, schema, types, typeConversionsFactory); @@ -323,7 +323,7 @@ public static List getSetters( TypeConversionsFactory typeConversionsFactory) { // Return the setters, ordered by their position in the schema. return CACHED_SETTERS.computeIfAbsent( - new ClassWithSchema(clazz, schema), + ClassWithSchema.create(clazz, schema), c -> { List types = fieldValueTypeSupplier.get(clazz, schema); return types.stream() diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java index d56f0bd152f4..08c494c30c52 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ReflectUtils.java @@ -19,6 +19,7 @@ import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; +import com.google.auto.value.AutoValue; import java.lang.reflect.Constructor; import java.lang.reflect.Field; import java.lang.reflect.Method; @@ -31,7 +32,6 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.stream.Collectors; import javax.annotation.Nullable; import org.apache.beam.sdk.schemas.Schema; @@ -39,35 +39,21 @@ import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Multimap; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Multimaps; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.primitives.Primitives; /** A set of reflection helper methods. */ public class ReflectUtils { /** Represents a class and a schema. */ - public static class ClassWithSchema { - private final Class clazz; - private final Schema schema; + @AutoValue + public abstract static class ClassWithSchema { + public abstract Class getClazz(); - public ClassWithSchema(Class clazz, Schema schema) { - this.clazz = clazz; - this.schema = schema; - } + public abstract Schema getSchema(); - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - ClassWithSchema that = (ClassWithSchema) o; - return Objects.equals(clazz, that.clazz) && Objects.equals(schema, that.schema); - } - - @Override - public int hashCode() { - return Objects.hash(clazz, schema); + public static ClassWithSchema create(Class clazz, Schema schema) { + return new AutoValue_ReflectUtils_ClassWithSchema(clazz, schema); } } @@ -94,6 +80,10 @@ public static List getMethods(Class clazz) { }); } + public static Multimap getMethodsMap(Class clazz) { + return Multimaps.index(getMethods(clazz), Method::getName); + } + @Nullable public static Constructor getAnnotatedConstructor(Class clazz) { return Arrays.stream(clazz.getDeclaredConstructors()) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java index ebf59b9216fc..399824ea5f9b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java @@ -29,6 +29,7 @@ import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.Schema.TypeName; +import org.apache.beam.sdk.schemas.logicaltypes.OneOfType; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Collections2; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; @@ -122,6 +123,20 @@ private T getValue(FieldType type, Object fieldValue, @Nullable Integer cach cacheKey, i -> getMapValue(type.getMapKeyType(), type.getMapValueType(), map)) : (T) getMapValue(type.getMapKeyType(), type.getMapValueType(), map); } else { + if (type.isLogicalType(OneOfType.IDENTIFIER)) { + // If this is a OneOf union type, we must extract the oneOfType corresponding to the current + // union value. + OneOfType oneOfType = type.getLogicalType(OneOfType.class); + OneOfType.Value oneOfValue = (OneOfType.Value) fieldValue; + Object convertedOneOfField = + getValue(oneOfValue.getFieldType(), oneOfValue.getValue(), null); + // Row.getValue by default returns the base representation type of logical types (for OneOf + // a Row with nullable + // fields for each option). Convert the result to the base type before returning. + return (T) + oneOfType.toBaseType( + oneOfType.createValue(oneOfValue.getCaseType(), convertedOneOfField)); + } return (T) fieldValue; } } diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java new file mode 100644 index 000000000000..308663941def --- /dev/null +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java @@ -0,0 +1,600 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.protobuf; + +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.getFieldNumber; + +import com.google.protobuf.ByteString; +import com.google.protobuf.Duration; +import com.google.protobuf.Internal.EnumLite; +import com.google.protobuf.MessageLite; +import com.google.protobuf.ProtocolMessageEnum; +import com.google.protobuf.Timestamp; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Type; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import javax.annotation.Nullable; +import org.apache.beam.sdk.extensions.protobuf.ProtoSchemaLogicalTypes.DurationNanos; +import org.apache.beam.sdk.extensions.protobuf.ProtoSchemaLogicalTypes.TimestampNanos; +import org.apache.beam.sdk.schemas.FieldValueGetter; +import org.apache.beam.sdk.schemas.FieldValueSetter; +import org.apache.beam.sdk.schemas.FieldValueTypeInformation; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.Field; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.Schema.TypeName; +import org.apache.beam.sdk.schemas.SchemaUserTypeCreator; +import org.apache.beam.sdk.schemas.logicaltypes.OneOfType; +import org.apache.beam.sdk.schemas.logicaltypes.OneOfType.Value; +import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.ConvertType; +import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.ConvertValueForGetter; +import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.ConvertValueForSetter; +import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.InjectPackageStrategy; +import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.TypeConversion; +import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.TypeConversionsFactory; +import org.apache.beam.sdk.schemas.utils.FieldValueTypeSupplier; +import org.apache.beam.sdk.schemas.utils.JavaBeanUtils; +import org.apache.beam.sdk.schemas.utils.ReflectUtils; +import org.apache.beam.sdk.schemas.utils.ReflectUtils.ClassWithSchema; +import org.apache.beam.sdk.util.common.ReflectHelpers; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.ByteBuddy; +import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.description.type.TypeDescription.ForLoadedType; +import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.dynamic.DynamicType; +import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.dynamic.loading.ClassLoadingStrategy; +import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.dynamic.scaffold.InstrumentedType; +import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.implementation.Implementation; +import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.implementation.bytecode.ByteCodeAppender; +import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.implementation.bytecode.StackManipulation; +import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.implementation.bytecode.StackManipulation.Compound; +import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.implementation.bytecode.assign.Assigner; +import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.implementation.bytecode.assign.Assigner.Typing; +import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.implementation.bytecode.assign.TypeCasting; +import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.implementation.bytecode.member.MethodInvocation; +import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.implementation.bytecode.member.MethodReturn; +import org.apache.beam.vendor.bytebuddy.v1_9_3.net.bytebuddy.matcher.ElementMatchers; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.CaseFormat; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Multimap; + +/** Code generation utilities to enable {@link ProtoRecordSchema}. */ +public class ProtoByteBuddyUtils { + private static final ByteBuddy BYTE_BUDDY = new ByteBuddy(); + private static TypeDescriptor BYTE_STRING_TYPE_DESCRIPTOR = + TypeDescriptor.of(ByteString.class); + private static TypeDescriptor PROTO_TIMESTAMP_TYPE_DESCRIPTOR = + TypeDescriptor.of(Timestamp.class); + private static TypeDescriptor PROTO_DURATION_TYPE_DESCRIPTOR = + TypeDescriptor.of(Duration.class); + private static TypeDescriptor PROTO_MESSAGE_ENUM_TYPE_DESCRIPTOR = + TypeDescriptor.of(ProtocolMessageEnum.class); + + private static final ForLoadedType BYTE_STRING_TYPE = new ForLoadedType(ByteString.class); + private static final ForLoadedType BYTE_ARRAY_TYPE = new ForLoadedType(byte[].class); + private static final ForLoadedType PROTO_ENUM_TYPE = new ForLoadedType(ProtocolMessageEnum.class); + private static final ForLoadedType INTEGER_TYPE = new ForLoadedType(Integer.class); + private static final ForLoadedType TIMESTAMP_NANOS_TYPE = new ForLoadedType(TimestampNanos.class); + private static final ForLoadedType DURATION_NANOS_TYPE = new ForLoadedType(DurationNanos.class); + + // The following proto types have special suffixes on the generated getters. + private static final Map PROTO_GETTER_SUFFIX = + ImmutableMap.of( + TypeName.ARRAY, "List", + TypeName.ITERABLE, "List", + TypeName.MAP, "Map"); + // By default proto getters always start with get. + private static final String DEFAULT_PROTO_GETTER_PREFIX = "get"; + + // The following proto types have special prefixes on the generated setters. + private static final Map PROTO_SETTER_PREFIX = + ImmutableMap.of( + TypeName.ARRAY, "addAll", + TypeName.ITERABLE, "addAll", + TypeName.MAP, "putAll"); + // The remaining proto types have setters that start with set. + private static final String DEFAULT_PROTO_SETTER_PREFIX = "set"; + + // Given a name and a type, generate the proto getter name. + static String protoGetterName(String name, FieldType fieldType) { + final String camel = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.UPPER_CAMEL, name); + return DEFAULT_PROTO_GETTER_PREFIX + + camel + + PROTO_GETTER_SUFFIX.getOrDefault(fieldType.getTypeName(), ""); + } + + // Given a name and a type, generate the proto builder setter name. + static String protoSetterName(String name, FieldType fieldType) { + final String camel = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.UPPER_CAMEL, name); + return protoSetterPrefix(fieldType) + camel; + } + + static String protoSetterPrefix(FieldType fieldType) { + return PROTO_SETTER_PREFIX.getOrDefault(fieldType.getTypeName(), DEFAULT_PROTO_SETTER_PREFIX); + } + + // Converts the Java type returned by a proto getter to the type that Row.getValue will return. + static class ProtoConvertType extends ConvertType { + ProtoConvertType(boolean returnRawValues) { + super(returnRawValues); + } + + @Override + public Type convert(TypeDescriptor typeDescriptor) { + if (typeDescriptor.equals(BYTE_STRING_TYPE_DESCRIPTOR) + || typeDescriptor.isSubtypeOf(BYTE_STRING_TYPE_DESCRIPTOR)) { + return byte[].class; + } else if (typeDescriptor.isSubtypeOf(PROTO_MESSAGE_ENUM_TYPE_DESCRIPTOR)) { + return Integer.class; + } else if (typeDescriptor.equals(PROTO_TIMESTAMP_TYPE_DESCRIPTOR) + || typeDescriptor.equals(PROTO_DURATION_TYPE_DESCRIPTOR)) { + return Row.class; + } else { + return super.convert(typeDescriptor); + } + } + } + + // Given a StackManipulation that reads a value from a proto (by invoking the getter), generate + // code to convert + // that into the type that Row.getValue is expected to return. + static class ProtoConvertValueForGetter extends ConvertValueForGetter { + ProtoConvertValueForGetter(StackManipulation readValue) { + super(readValue); + } + + @Override + protected ProtoTypeConversionsFactory getFactory() { + return new ProtoTypeConversionsFactory(); + } + + @Override + public StackManipulation convert(TypeDescriptor type) { + if (type.equals(BYTE_STRING_TYPE_DESCRIPTOR) + || type.isSubtypeOf(BYTE_STRING_TYPE_DESCRIPTOR)) { + // For ByteString values, return ByteString.toByteArray. + return new Compound( + readValue, + MethodInvocation.invoke( + BYTE_STRING_TYPE + .getDeclaredMethods() + .filter(ElementMatchers.named("toByteArray")) + .getOnly())); + } else if (type.isSubtypeOf(PROTO_MESSAGE_ENUM_TYPE_DESCRIPTOR)) { + // If the type is ProtocolMessageEnum, then return ProtocolMessageEnum.getNumber. + return new Compound( + readValue, + MethodInvocation.invoke( + PROTO_ENUM_TYPE + .getDeclaredMethods() + .filter( + ElementMatchers.named("getNumber").and(ElementMatchers.takesArguments(0))) + .getOnly()), + Assigner.DEFAULT.assign( + INTEGER_TYPE.asUnboxed().asGenericType(), + INTEGER_TYPE.asGenericType(), + Typing.STATIC)); + } else if (type.equals(PROTO_TIMESTAMP_TYPE_DESCRIPTOR)) { + // If the type is a proto timestamp, then convert it to the appropriate row. + return new Compound( + readValue, + MethodInvocation.invoke( + TIMESTAMP_NANOS_TYPE + .getDeclaredMethods() + .filter(ElementMatchers.named("toRow")) + .getOnly())); + } else if (type.equals(PROTO_DURATION_TYPE_DESCRIPTOR)) { + // If the type is a proto duration, then convert it to the appropriate row. + return new Compound( + readValue, + MethodInvocation.invoke( + DURATION_NANOS_TYPE + .getDeclaredMethods() + .filter(ElementMatchers.named("toRow")) + .getOnly())); + } else { + return super.convert(type); + } + } + } + + // Convert from the type returned by Row.getValue to the type expected by a proto builder setter. + static class ProtoConvertValueForSetter extends ConvertValueForSetter { + ProtoConvertValueForSetter(StackManipulation readValue) { + super(readValue); + } + + @Override + protected ProtoTypeConversionsFactory getFactory() { + return new ProtoTypeConversionsFactory(); + } + + @Override + public StackManipulation convert(TypeDescriptor type) { + if (type.isSubtypeOf(BYTE_STRING_TYPE_DESCRIPTOR)) { + // Convert a byte[] to a ByteString. + return new Compound( + readValue, + TypeCasting.to(BYTE_ARRAY_TYPE), + MethodInvocation.invoke( + BYTE_STRING_TYPE + .getDeclaredMethods() + .filter( + ElementMatchers.named("copyFrom") + .and(ElementMatchers.takesArguments(BYTE_ARRAY_TYPE))) + .getOnly())); + } else if (type.isSubtypeOf(PROTO_MESSAGE_ENUM_TYPE_DESCRIPTOR)) { + ForLoadedType loadedType = new ForLoadedType(type.getRawType()); + // Convert the stored number back to the enum constant. + return new Compound( + readValue, + Assigner.DEFAULT.assign( + INTEGER_TYPE.asBoxed().asGenericType(), + INTEGER_TYPE.asUnboxed().asGenericType(), + Typing.STATIC), + MethodInvocation.invoke( + loadedType + .getDeclaredMethods() + .filter( + ElementMatchers.named("forNumber") + .and(ElementMatchers.isStatic().and(ElementMatchers.takesArguments(1)))) + .getOnly())); + } else if (type.equals(PROTO_TIMESTAMP_TYPE_DESCRIPTOR)) { + // Convert to a proto timestamp. + return new Compound( + readValue, + MethodInvocation.invoke( + TIMESTAMP_NANOS_TYPE + .getDeclaredMethods() + .filter(ElementMatchers.named("toTimestamp")) + .getOnly())); + } else if (type.equals(PROTO_DURATION_TYPE_DESCRIPTOR)) { + // Convert to a proto duration. + return new Compound( + readValue, + MethodInvocation.invoke( + DURATION_NANOS_TYPE + .getDeclaredMethods() + .filter(ElementMatchers.named("toDuration")) + .getOnly())); + } else { + return super.convert(type); + } + } + } + + // A factory that is injected to allow injection of the above TypeConversion classes. + static class ProtoTypeConversionsFactory implements TypeConversionsFactory { + @Override + public TypeConversion createTypeConversion(boolean returnRawTypes) { + return new ProtoConvertType(returnRawTypes); + } + + @Override + public TypeConversion createGetterConversions(StackManipulation readValue) { + return new ProtoConvertValueForGetter(readValue); + } + + @Override + public TypeConversion createSetterConversions(StackManipulation readValue) { + return new ProtoConvertValueForSetter(readValue); + } + } + + // The list of getters for a class is cached, so we only create the classes the first time + // getSetters is called. + private static final Map> CACHED_GETTERS = + Maps.newConcurrentMap(); + + /** + * Return the list of {@link FieldValueGetter}s for a Java Bean class + * + *

The returned list is ordered by the order of fields in the schema. + */ + public static List getGetters( + Class clazz, + Schema schema, + FieldValueTypeSupplier fieldValueTypeSupplier, + TypeConversionsFactory typeConversionsFactory) { + Multimap methods = ReflectUtils.getMethodsMap(clazz); + return CACHED_GETTERS.computeIfAbsent( + ClassWithSchema.create(clazz, schema), + c -> { + List types = fieldValueTypeSupplier.get(clazz, schema); + return types.stream() + .map( + t -> + createGetter( + t, + typeConversionsFactory, + clazz, + methods, + schema.getField(t.getName()), + fieldValueTypeSupplier)) + .collect(Collectors.toList()); + }); + } + + private static FieldValueGetter createGetter( + FieldValueTypeInformation fieldValueTypeInformation, + TypeConversionsFactory typeConversionsFactory, + Class clazz, + Multimap methods, + Field field, + FieldValueTypeSupplier fieldValueTypeSupplier) { + if (field.getType().isLogicalType(OneOfType.IDENTIFIER)) { + OneOfType oneOfType = field.getType().getLogicalType(OneOfType.class); + + // The case accessor method in the proto is named getOneOfNameCase. + Method caseMethod = + getProtoGetter( + methods, + field.getName() + "_case", + FieldType.logicalType(oneOfType.getCaseEnumType())); + Map oneOfGetters = Maps.newHashMap(); + Map oneOfFieldTypes = + fieldValueTypeSupplier.get(clazz, oneOfType.getOneOfSchema()).stream() + .collect(Collectors.toMap(FieldValueTypeInformation::getName, f -> f)); + for (Field oneOfField : oneOfType.getOneOfSchema().getFields()) { + int protoFieldIndex = getFieldNumber(oneOfField.getType()); + FieldValueGetter oneOfFieldGetter = + createGetter( + oneOfFieldTypes.get(oneOfField.getName()), + typeConversionsFactory, + clazz, + methods, + oneOfField, + fieldValueTypeSupplier); + oneOfGetters.put(protoFieldIndex, oneOfFieldGetter); + } + return new OneOfFieldValueGetter(field.getName(), caseMethod, oneOfGetters, oneOfType); + } else { + return JavaBeanUtils.createGetter(fieldValueTypeInformation, typeConversionsFactory); + } + } + + private static Class getProtoGeneratedBuilder(Class clazz) { + String builderClassName = clazz.getName() + "$Builder"; + try { + return Class.forName(builderClassName); + } catch (ClassNotFoundException e) { + return null; + } + } + + static Method getProtoSetter(Multimap methods, String name, FieldType fieldType) { + final TypeDescriptor builderDescriptor = + TypeDescriptor.of(MessageLite.Builder.class); + return methods.get(protoSetterName(name, fieldType)).stream() + // Setter methods take only a single parameter. + .filter(m -> m.getParameterCount() == 1) + // For nested types, we don't use the version that takes a builder. + .filter( + m -> !TypeDescriptor.of(m.getGenericParameterTypes()[0]).isSubtypeOf(builderDescriptor)) + .findAny() + .orElseThrow(IllegalArgumentException::new); + } + + static Method getProtoGetter(Multimap methods, String name, FieldType fieldType) { + return methods.get(protoGetterName(name, fieldType)).stream() + .filter(m -> m.getParameterCount() == 0) + .findAny() + .orElseThrow(IllegalArgumentException::new); + } + + @Nullable + public static SchemaUserTypeCreator getBuilderCreator( + Class protoClass, Schema schema, FieldValueTypeSupplier fieldValueTypeSupplier) { + Class builderClass = getProtoGeneratedBuilder(protoClass); + if (builderClass == null) { + return null; + } + + List setters = Lists.newArrayListWithCapacity(schema.getFieldCount()); + Multimap methods = ReflectUtils.getMethodsMap(builderClass); + for (Field field : schema.getFields()) { + if (field.getType().isLogicalType(OneOfType.IDENTIFIER)) { + OneOfType oneOfType = field.getType().getLogicalType(OneOfType.class); + Map oneOfMethods = Maps.newHashMap(); + for (Field oneOfField : oneOfType.getOneOfSchema().getFields()) { + Method method = getProtoSetter(methods, oneOfField.getName(), oneOfField.getType()); + oneOfMethods.put(getFieldNumber(oneOfField.getType()), method); + } + setters.add(new OneOfFieldValueSetter(oneOfMethods, field.getName())); + } else { + Method method = getProtoSetter(methods, field.getName(), field.getType()); + setters.add( + JavaBeanUtils.createSetter( + FieldValueTypeInformation.forSetter(method, protoSetterPrefix(field.getType())), + new ProtoTypeConversionsFactory())); + } + } + + return createBuilderCreator(protoClass, builderClass, setters, schema); + } + + /** + * A getter for a oneof value. Ideally we would codegen this as well to avoid map lookups on each + * invocation. However generating switch statements with byte buddy is complicated, so for now + * we're using a map. + */ + static class OneOfFieldValueGetter + implements FieldValueGetter { + private final String name; + private final Method getCaseMethod; + private final Map> getterMethodMap; + private final OneOfType oneOfType; + + public OneOfFieldValueGetter( + String name, + Method getCaseMethod, + Map> getterMethodMap, + OneOfType oneOfType) { + this.name = name; + this.getCaseMethod = getCaseMethod; + this.getterMethodMap = getterMethodMap; + this.oneOfType = oneOfType; + } + + @Nullable + @Override + public Value get(ProtoT object) { + try { + EnumLite caseValue = (EnumLite) getCaseMethod.invoke(object); + if (caseValue.getNumber() == 0) { + return null; + } else { + Object value = getterMethodMap.get(caseValue.getNumber()).get(object); + return oneOfType.createValue( + oneOfType.getCaseEnumType().valueOf(caseValue.getNumber()), value); + } + } catch (IllegalAccessException | InvocationTargetException e) { + throw new RuntimeException(e); + } + } + + @Override + public String name() { + return name; + } + } + + /** + * A setter for a OneOf value. Ideally we would codegen this, as this class requires a map lookup + * as well as a reflection-based method invocation - both of which can be expensive. However + * generating switch statements with ByteBuddy is a bit complicated, so for now we're doing it + * this way. + */ + static class OneOfFieldValueSetter + implements FieldValueSetter { + private final Map methods; + private final String name; + + OneOfFieldValueSetter(Map methods, String name) { + this.methods = methods; + this.name = name; + } + + @Override + public void set(BuilderT builder, OneOfType.Value oneOfValue) { + Method method = methods.get(oneOfValue.getCaseType().getValue()); + try { + method.invoke(builder, oneOfValue.getValue()); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new RuntimeException(e); + } + } + + @Override + public String name() { + return name; + } + } + + static SchemaUserTypeCreator createBuilderCreator( + Class protoClass, Class builderClass, List setters, Schema schema) { + try { + DynamicType.Builder builder = + BYTE_BUDDY + .with(new InjectPackageStrategy(builderClass)) + .subclass(Supplier.class) + .method(ElementMatchers.named("get")) + .intercept(new BuilderSupplier(protoClass)); + Supplier supplier = + builder + .make() + .load(ReflectHelpers.findClassLoader(), ClassLoadingStrategy.Default.INJECTION) + .getLoaded() + .getDeclaredConstructor() + .newInstance(); + return new ProtoCreatorFactory(supplier, setters); + } catch (InstantiationException + | IllegalAccessException + | NoSuchMethodException + | InvocationTargetException e) { + throw new RuntimeException( + "Unable to generate a creator for class " + builderClass + " with schema " + schema); + } + } + + // This is the class that actually creates a proto buffer. + static class ProtoCreatorFactory implements SchemaUserTypeCreator { + private final Supplier builderCreator; + private final List setters; + + public ProtoCreatorFactory( + Supplier builderCreator, List setters) { + this.builderCreator = builderCreator; + this.setters = setters; + } + + @Override + public Object create(Object... params) { + MessageLite.Builder builder = builderCreator.get(); + for (int i = 0; i < params.length; ++i) { + setters.get(i).set(builder, params[i]); + } + return builder.build(); + } + } + + // This is the implementation of a Supplier class that when invoked returns a builder for the + // specified protocol + // buffer. + static class BuilderSupplier implements Implementation { + private final Class protoClass; + + public BuilderSupplier(Class protoClass) { + this.protoClass = protoClass; + } + + @Override + public InstrumentedType prepare(InstrumentedType instrumentedType) { + return instrumentedType; + } + + @Override + public ByteCodeAppender appender(final Target implementationTarget) { + ForLoadedType loadedProto = new ForLoadedType(protoClass); + return (methodVisitor, implementationContext, instrumentedMethod) -> { + // this + method parameters. + int numLocals = 1 + instrumentedMethod.getParameters().size(); + + // Create the builder object by calling ProtoClass.newBuilder(). + StackManipulation stackManipulation = + new StackManipulation.Compound( + MethodInvocation.invoke( + loadedProto + .getDeclaredMethods() + .filter( + ElementMatchers.named("newBuilder") + .and(ElementMatchers.takesArguments(0))) + .getOnly()), + MethodReturn.REFERENCE); + StackManipulation.Size size = stackManipulation.apply(methodVisitor, implementationContext); + return new ByteCodeAppender.Size(size.getMaximalSize(), numLocals); + }; + } + } +} diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchema.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchema.java new file mode 100644 index 000000000000..f37b7c7501e9 --- /dev/null +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchema.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.protobuf; + +import com.google.protobuf.DynamicMessage; +import java.util.List; +import javax.annotation.Nullable; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.annotations.Experimental.Kind; +import org.apache.beam.sdk.schemas.FieldValueGetter; +import org.apache.beam.sdk.schemas.FieldValueTypeInformation; +import org.apache.beam.sdk.schemas.GetterBasedSchemaProvider; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaUserTypeCreator; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; + +@Experimental(Kind.SCHEMAS) +public class ProtoDynamicMessageSchema extends GetterBasedSchemaProvider { + private static final TypeDescriptor DYNAMIC_MESSAGE_TYPE_DESCRIPTOR = + TypeDescriptor.of(DynamicMessage.class); + + @Nullable + @Override + public Schema schemaFor(TypeDescriptor typeDescriptor) { + checkForDynamicType(typeDescriptor); + return ProtoSchemaTranslator.getSchema((Class) typeDescriptor.getRawType()); + } + + @Override + public List fieldValueGetters(Class targetClass, Schema schema) { + return null; + } + + @Override + public List fieldValueTypeInformations( + Class targetClass, Schema schema) { + List types = Lists.newArrayListWithCapacity(schema.getFieldCount()); + return null; + } + + @Override + public SchemaUserTypeCreator schemaTypeCreator(Class targetClass, Schema schema) { + return null; + } + + private void checkForDynamicType(TypeDescriptor typeDescriptor) { + if (!typeDescriptor.isSubtypeOf(DYNAMIC_MESSAGE_TYPE_DESCRIPTOR)) { + throw new IllegalArgumentException("ProtoDynamicMessageSchema only handles DynamicMessages."); + } + } +} diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoFieldOverlay.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoFieldOverlay.java new file mode 100644 index 000000000000..7f0b04c6d3b1 --- /dev/null +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoFieldOverlay.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.protobuf; + +import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.Message; +import javax.annotation.Nullable; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.annotations.Experimental.Kind; +import org.apache.beam.sdk.schemas.FieldValueGetter; +import org.apache.beam.sdk.schemas.FieldValueSetter; + +@Experimental(Kind.SCHEMAS) +public abstract class ProtoFieldOverlay + implements FieldValueGetter, FieldValueSetter { + protected final FieldDescriptor fieldDescriptor; + private final String name; + + public ProtoFieldOverlay(FieldDescriptor fieldDescriptor, String name) { + this.fieldDescriptor = fieldDescriptor; + this.name = name; + } + + @Nullable + @Override + public abstract ValueT get(Message object); + + @Override + public abstract void set(Message.Builder builder, @Nullable ValueT value); + + @Override + public String name() { + return name; + } +} diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoFieldOverlays.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoFieldOverlays.java new file mode 100644 index 000000000000..b0e6559d5bf7 --- /dev/null +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoFieldOverlays.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.protobuf; + +import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.Message; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects; + +public class ProtoFieldOverlays { + + /** Overlay for Protobuf primitive types. Primitive values are just passed through. */ + class PrimitiveOverlay extends ProtoFieldOverlay { + transient Object cached = null; + + PrimitiveOverlay(FieldDescriptor fieldDescriptor, String name) { + super(fieldDescriptor, name); + } + + @Override + public Object get(Message message) { + return MoreObjects.firstNonNull(cached, cached = message.getField(fieldDescriptor)); + } + + @Override + public void set(Message.Builder builder, Object value) { + builder.setField( + builder.getDescriptorForType().findFieldByNumber(fieldDescriptor.getNumber()), value); + } + } + + class BytesOverlay extends ProtoFieldOverlay { + transient byte[] cached = null; + + BytesOverlay(FieldDescriptor fieldDescriptor, String name) { + super(fieldDescriptor, name); + } + + @Override + public Object get(Message message) { + return MoreObjects.firstNonNull(cached, cached = message.getField(fieldDescriptor)); + } + + @Override + public void set(Message.Builder builder, Object value) { + builder.setField( + builder.getDescriptorForType().findFieldByNumber(fieldDescriptor.getNumber()), value); + } + } +} diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoRecordSchema.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoRecordSchema.java new file mode 100644 index 000000000000..15e223db16d3 --- /dev/null +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoRecordSchema.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.protobuf; + +import static org.apache.beam.sdk.extensions.protobuf.ProtoByteBuddyUtils.getProtoGetter; + +import com.google.protobuf.DynamicMessage; +import com.google.protobuf.Message; +import java.lang.reflect.Method; +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.annotations.Experimental.Kind; +import org.apache.beam.sdk.extensions.protobuf.ProtoByteBuddyUtils.ProtoTypeConversionsFactory; +import org.apache.beam.sdk.schemas.FieldValueGetter; +import org.apache.beam.sdk.schemas.FieldValueTypeInformation; +import org.apache.beam.sdk.schemas.GetterBasedSchemaProvider; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.Field; +import org.apache.beam.sdk.schemas.SchemaUserTypeCreator; +import org.apache.beam.sdk.schemas.logicaltypes.OneOfType; +import org.apache.beam.sdk.schemas.utils.FieldValueTypeSupplier; +import org.apache.beam.sdk.schemas.utils.JavaBeanUtils; +import org.apache.beam.sdk.schemas.utils.ReflectUtils; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Multimap; + +/** + * A {@link SchemaProvider} for Protocol Buffer objects. + * + *

This provider is for statically compiled protocol buffer objects. It generates the full schema + * along with efficient conversions to and from the internal {@link Row} object. + */ +@Experimental(Kind.SCHEMAS) +public class ProtoRecordSchema extends GetterBasedSchemaProvider { + + private static final class ProtoClassFieldValueTypeSupplier implements FieldValueTypeSupplier { + @Override + public List get(Class clazz) { + throw new RuntimeException("Unexpected call."); + } + + @Override + public List get(Class clazz, Schema schema) { + Multimap methods = ReflectUtils.getMethodsMap(clazz); + List types = + Lists.newArrayListWithCapacity(schema.getFieldCount()); + for (Field field : schema.getFields()) { + if (field.getType().isLogicalType(OneOfType.IDENTIFIER)) { + // This is a OneOf. Look for the getters for each OneOf option. + OneOfType oneOfType = field.getType().getLogicalType(OneOfType.class); + Map oneOfTypes = Maps.newHashMap(); + for (Field oneOfField : oneOfType.getOneOfSchema().getFields()) { + Method method = getProtoGetter(methods, oneOfField.getName(), oneOfField.getType()); + oneOfTypes.put( + oneOfField.getName(), + FieldValueTypeInformation.forGetter(method).withName(field.getName())); + } + // Add an entry that encapsulates information about all possible getters. + types.add( + FieldValueTypeInformation.forOneOf( + field.getName(), field.getType().getNullable(), oneOfTypes) + .withName(field.getName())); + } else { + // This is a simple field. Add the getter. + Method method = getProtoGetter(methods, field.getName(), field.getType()); + types.add(FieldValueTypeInformation.forGetter(method).withName(field.getName())); + } + } + return types; + } + } + + @Nullable + @Override + public Schema schemaFor(TypeDescriptor typeDescriptor) { + checkForDynamicType(typeDescriptor); + return ProtoSchemaTranslator.getSchema((Class) typeDescriptor.getRawType()); + } + + @Override + public List fieldValueGetters(Class targetClass, Schema schema) { + return ProtoByteBuddyUtils.getGetters( + targetClass, + schema, + new ProtoClassFieldValueTypeSupplier(), + new ProtoTypeConversionsFactory()); + } + + @Override + public List fieldValueTypeInformations( + Class targetClass, Schema schema) { + return JavaBeanUtils.getFieldTypes(targetClass, schema, new ProtoClassFieldValueTypeSupplier()); + } + + @Override + public SchemaUserTypeCreator schemaTypeCreator(Class targetClass, Schema schema) { + SchemaUserTypeCreator creator = + ProtoByteBuddyUtils.getBuilderCreator( + targetClass, schema, new ProtoClassFieldValueTypeSupplier()); + if (creator == null) { + throw new RuntimeException("Cannot create creator for " + targetClass); + } + return creator; + } + + private void checkForDynamicType(TypeDescriptor typeDescriptor) { + if (typeDescriptor.getRawType().equals(DynamicMessage.class)) { + throw new RuntimeException( + "DynamicMessage is not allowed for the standard ProtoSchemaProvider, use ProtoDynamicMessageSchema instead."); + } + } +} diff --git a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoRecordSchemaTest.java b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoRecordSchemaTest.java new file mode 100644 index 000000000000..7f2723fc35b2 --- /dev/null +++ b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoRecordSchemaTest.java @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.protobuf; + +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withFieldNumber; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.MAP_PRIMITIVE_PROTO; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.MAP_PRIMITIVE_ROW; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.MAP_PRIMITIVE_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NESTED_PROTO; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NESTED_ROW; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NESTED_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_PROTO_BOOL; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_PROTO_INT32; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_PROTO_PRIMITIVE; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_PROTO_STRING; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_ROW_BOOL; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_ROW_INT32; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_ROW_PRIMITIVE; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_ROW_STRING; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.OUTER_ONEOF_PROTO; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.OUTER_ONEOF_ROW; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.OUTER_ONEOF_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.PRIMITIVE_PROTO; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.PRIMITIVE_ROW; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.PRIMITIVE_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REPEATED_PROTO; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REPEATED_ROW; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REPEATED_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.WKT_MESSAGE_PROTO; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.WKT_MESSAGE_ROW; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.WKT_MESSAGE_SCHEMA; +import static org.junit.Assert.assertEquals; + +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.EnumMessage; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.EnumMessage.Enum; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.MapPrimitive; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.Nested; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.OneOf; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.OuterOneOf; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.Primitive; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.RepeatPrimitive; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.WktMessage; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link ProtoRecordSchema}. */ +@RunWith(JUnit4.class) +public class ProtoRecordSchemaTest { + + @Test + public void testPrimitiveSchema() { + Schema schema = new ProtoRecordSchema().schemaFor(TypeDescriptor.of(Primitive.class)); + assertEquals(PRIMITIVE_SCHEMA, schema); + } + + @Test + public void testPrimitiveProtoToRow() { + SerializableFunction toRow = + new ProtoRecordSchema().toRowFunction(TypeDescriptor.of(Primitive.class)); + assertEquals(PRIMITIVE_ROW, toRow.apply(PRIMITIVE_PROTO)); + } + + @Test + public void testPrimitiveRowToProto() { + SerializableFunction fromRow = + new ProtoRecordSchema().fromRowFunction(TypeDescriptor.of(Primitive.class)); + assertEquals(PRIMITIVE_PROTO, fromRow.apply(PRIMITIVE_ROW)); + } + + @Test + public void testRepeatedSchema() { + Schema schema = new ProtoRecordSchema().schemaFor(TypeDescriptor.of(RepeatPrimitive.class)); + assertEquals(REPEATED_SCHEMA, schema); + } + + @Test + public void testRepeatedProtoToRow() { + SerializableFunction toRow = + new ProtoRecordSchema().toRowFunction(TypeDescriptor.of(RepeatPrimitive.class)); + assertEquals(REPEATED_ROW, toRow.apply(REPEATED_PROTO)); + } + + @Test + public void testRepeatedRowToProto() { + SerializableFunction fromRow = + new ProtoRecordSchema().fromRowFunction(TypeDescriptor.of(RepeatPrimitive.class)); + assertEquals(REPEATED_PROTO, fromRow.apply(REPEATED_ROW)); + } + + // Test map type + @Test + public void testMapSchema() { + Schema schema = new ProtoRecordSchema().schemaFor(TypeDescriptor.of(MapPrimitive.class)); + assertEquals(MAP_PRIMITIVE_SCHEMA, schema); + } + + @Test + public void testMapProtoToRow() { + SerializableFunction toRow = + new ProtoRecordSchema().toRowFunction(TypeDescriptor.of(MapPrimitive.class)); + assertEquals(MAP_PRIMITIVE_ROW, toRow.apply(MAP_PRIMITIVE_PROTO)); + } + + @Test + public void testMapRowToProto() { + SerializableFunction fromRow = + new ProtoRecordSchema().fromRowFunction(TypeDescriptor.of(MapPrimitive.class)); + assertEquals(MAP_PRIMITIVE_PROTO, fromRow.apply(MAP_PRIMITIVE_ROW)); + } + + @Test + public void testNestedSchema() { + Schema schema = new ProtoRecordSchema().schemaFor(TypeDescriptor.of(Nested.class)); + assertEquals(NESTED_SCHEMA, schema); + } + + @Test + public void testNestedProtoToRow() { + SerializableFunction toRow = + new ProtoRecordSchema().toRowFunction(TypeDescriptor.of(Nested.class)); + assertEquals(NESTED_ROW, toRow.apply(NESTED_PROTO)); + } + + @Test + public void testNestedRowToProto() { + SerializableFunction fromRow = + new ProtoRecordSchema().fromRowFunction(TypeDescriptor.of(Nested.class)); + assertEquals(NESTED_PROTO, fromRow.apply(NESTED_ROW)); + } + + @Test + public void testOneOfSchema() { + Schema schema = new ProtoRecordSchema().schemaFor(TypeDescriptor.of(OneOf.class)); + assertEquals(ONEOF_SCHEMA, schema); + } + + @Test + public void testOneOfProtoToRow() { + SerializableFunction toRow = + new ProtoRecordSchema().toRowFunction(TypeDescriptor.of(OneOf.class)); + assertEquals(ONEOF_ROW_INT32, toRow.apply(ONEOF_PROTO_INT32)); + assertEquals(ONEOF_ROW_BOOL, toRow.apply(ONEOF_PROTO_BOOL)); + assertEquals(ONEOF_ROW_STRING, toRow.apply(ONEOF_PROTO_STRING)); + assertEquals(ONEOF_ROW_PRIMITIVE, toRow.apply(ONEOF_PROTO_PRIMITIVE)); + } + + @Test + public void testOneOfRowToProto() { + SerializableFunction fromRow = + new ProtoRecordSchema().fromRowFunction(TypeDescriptor.of(OneOf.class)); + assertEquals(ONEOF_PROTO_INT32, fromRow.apply(ONEOF_ROW_INT32)); + assertEquals(ONEOF_PROTO_BOOL, fromRow.apply(ONEOF_ROW_BOOL)); + assertEquals(ONEOF_PROTO_STRING, fromRow.apply(ONEOF_ROW_STRING)); + assertEquals(ONEOF_PROTO_PRIMITIVE, fromRow.apply(ONEOF_ROW_PRIMITIVE)); + } + + @Test + public void testOuterOneOfSchema() { + Schema schema = new ProtoRecordSchema().schemaFor(TypeDescriptor.of(OuterOneOf.class)); + assertEquals(OUTER_ONEOF_SCHEMA, schema); + } + + @Test + public void testOuterOneOfProtoToRow() { + SerializableFunction toRow = + new ProtoRecordSchema().toRowFunction(TypeDescriptor.of(OuterOneOf.class)); + assertEquals(OUTER_ONEOF_ROW, toRow.apply(OUTER_ONEOF_PROTO)); + } + + @Test + public void testOuterOneOfRowToProto() { + SerializableFunction fromRow = + new ProtoRecordSchema().fromRowFunction(TypeDescriptor.of(OuterOneOf.class)); + assertEquals(OUTER_ONEOF_PROTO, fromRow.apply(OUTER_ONEOF_ROW)); + } + + private static EnumerationType ENUM_TYPE = + EnumerationType.create(ImmutableMap.of("ZERO", 0, "TWO", 2, "THREE", 3)); + private static final Schema ENUM_SCHEMA = + Schema.builder() + .addField("enum", withFieldNumber(FieldType.logicalType(ENUM_TYPE).withNullable(true), 1)) + .build(); + private static final Row ENUM_ROW = + Row.withSchema(ENUM_SCHEMA).addValues(ENUM_TYPE.valueOf("TWO")).build(); + private static final EnumMessage ENUM_PROTO = EnumMessage.newBuilder().setEnum(Enum.TWO).build(); + + @Test + public void testEnumSchema() { + Schema schema = new ProtoRecordSchema().schemaFor(TypeDescriptor.of(EnumMessage.class)); + assertEquals(ENUM_SCHEMA, schema); + } + + @Test + public void testEnumProtoToRow() { + SerializableFunction toRow = + new ProtoRecordSchema().toRowFunction(TypeDescriptor.of(EnumMessage.class)); + assertEquals(ENUM_ROW, toRow.apply(ENUM_PROTO)); + } + + @Test + public void testEnumRowToProto() { + SerializableFunction fromRow = + new ProtoRecordSchema().fromRowFunction(TypeDescriptor.of(EnumMessage.class)); + assertEquals(ENUM_PROTO, fromRow.apply(ENUM_ROW)); + } + + @Test + public void testWktMessageSchema() { + Schema schema = new ProtoRecordSchema().schemaFor(TypeDescriptor.of(WktMessage.class)); + assertEquals(WKT_MESSAGE_SCHEMA, schema); + } + + @Test + public void testWktProtoToRow() { + SerializableFunction toRow = + new ProtoRecordSchema().toRowFunction(TypeDescriptor.of(WktMessage.class)); + assertEquals(WKT_MESSAGE_ROW, toRow.apply(WKT_MESSAGE_PROTO)); + } + + @Test + public void testWktRowToProto() { + SerializableFunction fromRow = + new ProtoRecordSchema().fromRowFunction(TypeDescriptor.of(WktMessage.class)); + assertEquals(WKT_MESSAGE_PROTO, fromRow.apply(WKT_MESSAGE_ROW)); + } +}