Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Modifier;
import java.util.BitSet;
import java.util.List;
import java.util.Map;
Expand All @@ -32,7 +33,6 @@
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.Field;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.schemas.Schema.TypeName;
import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.ByteBuddy;
Expand All @@ -47,11 +47,9 @@
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.FixedValue;
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.Implementation;
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.ByteCodeAppender;
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.ByteCodeAppender.Size;
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.Duplication;
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.StackManipulation;
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.StackManipulation.Compound;
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.TypeCreation;
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.collection.ArrayFactory;
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.member.FieldAccess;
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.member.MethodInvocation;
import org.apache.beam.vendor.bytebuddy.v1_10_8.net.bytebuddy.implementation.bytecode.member.MethodReturn;
Expand Down Expand Up @@ -99,68 +97,100 @@
@Experimental(Kind.SCHEMAS)
public abstract class RowCoderGenerator {
private static final ByteBuddy BYTE_BUDDY = new ByteBuddy();
private static final ForLoadedType CODER_TYPE = new ForLoadedType(Coder.class);
private static final ForLoadedType LIST_CODER_TYPE = new ForLoadedType(ListCoder.class);
private static final ForLoadedType ITERABLE_CODER_TYPE = new ForLoadedType(IterableCoder.class);
private static final ForLoadedType MAP_CODER_TYPE = new ForLoadedType(MapCoder.class);
private static final BitSetCoder NULL_LIST_CODER = BitSetCoder.of();
private static final VarIntCoder VAR_INT_CODER = VarIntCoder.of();
private static final ForLoadedType NULLABLE_CODER = new ForLoadedType(NullableCoder.class);

private static final String CODERS_FIELD_NAME = "FIELD_CODERS";

// A map of primitive types -> StackManipulations to create their coders.
private static final Map<TypeName, StackManipulation> CODER_MAP;

// Cache of Coder instances that have already been generated, keyed by schema UUID.
private static Map<UUID, Coder<Row>> generatedCoders = Maps.newConcurrentMap();

static {
// Initialize the CODER_MAP with the StackManipulations to create the primitive coders.
// Assumes that each class contains a static of() constructor method.
CODER_MAP = Maps.newHashMap();
for (Map.Entry<TypeName, Coder> entry : SchemaCoder.CODER_MAP.entrySet()) {
StackManipulation stackManipulation =
MethodInvocation.invoke(
new ForLoadedType(entry.getValue().getClass())
.getDeclaredMethods()
.filter(ElementMatchers.named("of"))
.getOnly());
CODER_MAP.putIfAbsent(entry.getKey(), stackManipulation);
}
}
private static final Map<UUID, Coder<Row>> GENERATED_CODERS = Maps.newConcurrentMap();

/**
 * Returns a {@link Coder} for {@link Row}s of the given schema.
 *
 * <p>Coders are cached by {@code schema.getUUID()}. On a cache miss a new subclass of
 * {@code Coder<Row>} is generated with ByteBuddy, loaded, instantiated reflectively and stored in
 * the cache before being returned.
 *
 * @param schema the schema whose rows the returned coder handles
 * @return the cached or newly generated coder
 * @throws RuntimeException if the generated class cannot be instantiated reflectively
 */
@SuppressWarnings("unchecked")
public static Coder<Row> generate(Schema schema) {
// Using ConcurrentHashMap::computeIfAbsent here would deadlock in case of nested
// coders. Using HashMap::computeIfAbsent generates ConcurrentModificationExceptions in Java 11.
// NOTE(review): this listing appears to interleave the pre-change and post-change lines of a
// diff (both the generatedCoders and GENERATED_CODERS variants are present) - confirm which
// side applies before compiling.
Coder<Row> rowCoder = generatedCoders.get(schema.getUUID());
Coder<Row> rowCoder = GENERATED_CODERS.get(schema.getUUID());
if (rowCoder == null) {
// Cache miss: start a ByteBuddy subclass of the parameterized type Coder<Row>.
TypeDescription.Generic coderType =
TypeDescription.Generic.Builder.parameterizedType(Coder.class, Row.class).build();
DynamicType.Builder<Coder> builder =
(DynamicType.Builder<Coder>) BYTE_BUDDY.subclass(coderType);
builder = createComponentCoders(schema, builder);
builder = implementMethods(schema, builder);

// Build one component coder per schema field, in field order.
Coder[] componentCoders = new Coder[schema.getFieldCount()];
for (int i = 0; i < schema.getFieldCount(); ++i) {
// We use withNullable(false) as nulls are handled by the RowCoder and the individual
// component coders therefore do not need to handle nulls.
componentCoders[i] =
SchemaCoder.coderForFieldType(schema.getField(i).getType().withNullable(false));
}

// Private final instance field that holds the component coder array.
builder =
builder.defineField(
CODERS_FIELD_NAME, Coder[].class, Visibility.PRIVATE, FieldManifestation.FINAL);

// Public constructor taking a Coder[]; its bytecode comes from GeneratedCoderConstructor.
builder =
builder
.defineConstructor(Modifier.PUBLIC)
.withParameters(Coder[].class)
.intercept(new GeneratedCoderConstructor());

try {
// Inject the generated class into Coder's class loader and instantiate it via reflection.
rowCoder =
builder
.make()
.load(Coder.class.getClassLoader(), ClassLoadingStrategy.Default.INJECTION)
.getLoaded()
.getDeclaredConstructor()
.newInstance();
.getDeclaredConstructor(Coder[].class)
// The cast to Object passes the array as a single argument rather than as varargs.
.newInstance((Object) componentCoders);
} catch (InstantiationException
| IllegalAccessException
| NoSuchMethodException
| InvocationTargetException e) {
throw new RuntimeException("Unable to generate coder for schema " + schema);
throw new RuntimeException("Unable to generate coder for schema " + schema, e);
}
// Remember the coder so later calls with the same schema UUID reuse it.
generatedCoders.put(schema.getUUID(), rowCoder);
GENERATED_CODERS.put(schema.getUUID(), rowCoder);
}
return rowCoder;
}

/**
 * Bytecode implementation of the generated coder's constructor: it invokes the no-argument
 * {@link Coder} constructor and then stores the reference held in local slot 1 (the first
 * constructor argument) into the field named by {@code CODERS_FIELD_NAME}.
 */
private static class GeneratedCoderConstructor implements Implementation {
  @Override
  public InstrumentedType prepare(InstrumentedType instrumentedType) {
    // No changes to the instrumented type are made here.
    return instrumentedType;
  }

  @Override
  public ByteCodeAppender appender(final Target implementationTarget) {
    return (methodVisitor, implementationContext, instrumentedMethod) -> {
      // Local slots: one for "this" plus one per declared parameter.
      int localSlots = instrumentedMethod.getParameters().size() + 1;

      // super(): the zero-argument constructor declared on Coder.
      StackManipulation invokeBaseConstructor =
          MethodInvocation.invoke(
              new ForLoadedType(Coder.class)
                  .getDeclaredMethods()
                  .filter(ElementMatchers.isConstructor().and(ElementMatchers.takesArguments(0)))
                  .getOnly());

      // putfield on the coder-array field of the class being generated.
      StackManipulation storeCodersField =
          FieldAccess.forField(
                  implementationTarget
                      .getInstrumentedType()
                      .getDeclaredFields()
                      .filter(ElementMatchers.named(CODERS_FIELD_NAME))
                      .getOnly())
              .write();

      StackManipulation constructorBody =
          new StackManipulation.Compound(
              // "this" is pushed twice: once for the base constructor call, once for the store.
              MethodVariableAccess.loadThis(),
              Duplication.SINGLE,
              invokeBaseConstructor,
              // Store the array of Coders as a member variable.
              MethodVariableAccess.REFERENCE.loadFrom(1),
              storeCodersField,
              MethodReturn.VOID);

      StackManipulation.Size operandStack =
          constructorBody.apply(methodVisitor, implementationContext);
      return new Size(operandStack.getMaximalSize(), localSlots);
    };
  }
}

private static DynamicType.Builder<Coder> implementMethods(
Schema schema, DynamicType.Builder<Coder> builder) {
boolean hasNullableFields =
Expand All @@ -185,6 +215,7 @@ public ByteCodeAppender appender(Target implementationTarget) {
StackManipulation manipulation =
new StackManipulation.Compound(
// Array of coders.
MethodVariableAccess.loadThis(),
FieldAccess.forField(
implementationContext
.getInstrumentedType()
Expand Down Expand Up @@ -272,6 +303,7 @@ public ByteCodeAppender appender(Target implementationTarget) {
.filter(ElementMatchers.named("getSchema"))
.getOnly()),
// Array of coders.
MethodVariableAccess.loadThis(),
FieldAccess.forField(
implementationContext
.getInstrumentedType()
Expand Down Expand Up @@ -313,7 +345,8 @@ static Row decodeDelegate(Schema schema, Coder[] coders, InputStream inputStream
if (nullFields.get(i)) {
fieldValues.add(null);
} else {
fieldValues.add(coders[i].decode(inputStream));
Object fieldValue = coders[i].decode(inputStream);
fieldValues.add(fieldValue);
}
}
}
Expand All @@ -329,120 +362,4 @@ static Row decodeDelegate(Schema schema, Coder[] coders, InputStream inputStream
return Row.withSchema(schema).attachValues(fieldValues).build();
}
}

/**
 * Adds a private static final {@code Coder[]} field (named by {@code CODERS_FIELD_NAME}) to the
 * class being built, together with a static initializer that fills it with one coder per schema
 * field.
 *
 * @param schema the schema whose fields determine the component coders
 * @param builder the builder for the coder class under construction
 * @return the builder extended with the field and its static initializer
 */
private static DynamicType.Builder<Coder> createComponentCoders(
    Schema schema, DynamicType.Builder<Coder> builder) {
  int fieldCount = schema.getFieldCount();
  List<StackManipulation> coderFactories = Lists.newArrayListWithCapacity(fieldCount);
  for (int fieldIndex = 0; fieldIndex < fieldCount; fieldIndex++) {
    // Nulls are handled by the RowCoder itself, so each component coder is requested for the
    // non-nullable variant of its field type.
    FieldType nonNullableType = schema.getField(fieldIndex).getType().withNullable(false);
    coderFactories.add(getCoder(nonNullableType));
  }

  // Body of the static initializer: build the Coder[] and assign it to the static field.
  ByteCodeAppender staticInitializer =
      (methodVisitor, implementationContext, instrumentedMethod) -> {
        StackManipulation createCoderArray =
            ArrayFactory.forType(CODER_TYPE.asGenericType()).withValues(coderFactories);
        StackManipulation storeCoderArray =
            FieldAccess.forField(
                    implementationContext
                        .getInstrumentedType()
                        .getDeclaredFields()
                        .filter(ElementMatchers.named(CODERS_FIELD_NAME))
                        .getOnly())
                .write();
        StackManipulation.Size operandStack =
            new StackManipulation.Compound(createCoderArray, storeCoderArray)
                .apply(methodVisitor, implementationContext);
        return new ByteCodeAppender.Size(
            operandStack.getMaximalSize(), instrumentedMethod.getStackSize());
      };

  // private static final Coder[] FIELD_CODERS;
  DynamicType.Builder<Coder> withCodersField =
      builder.defineField(
          CODERS_FIELD_NAME,
          Coder[].class,
          Visibility.PRIVATE,
          Ownership.STATIC,
          FieldManifestation.FINAL);
  return withCodersField.initializer(staticInitializer);
}

/**
 * Returns the StackManipulation that creates a coder for the given field type.
 *
 * <p>Logical types delegate to their base type; arrays, iterables and maps are built from the
 * coders of their element/key/value types; row types use a coder class generated for the nested
 * schema; every other type is looked up as a primitive and, when the field type is nullable,
 * wrapped through {@code NullableCoder.of}.
 *
 * @param fieldType the schema field type to create a coder for
 * @return bytecode instructions that create the coder
 */
private static StackManipulation getCoder(Schema.FieldType fieldType) {
  TypeName typeName = fieldType.getTypeName();

  if (typeName == TypeName.LOGICAL_TYPE) {
    return getCoder(fieldType.getLogicalType().getBaseType());
  }
  if (typeName == TypeName.ARRAY) {
    return listCoder(fieldType.getCollectionElementType());
  }
  if (typeName == TypeName.ITERABLE) {
    return iterableCoder(fieldType.getCollectionElementType());
  }
  if (typeName == TypeName.MAP) {
    return mapCoder(fieldType.getMapKeyType(), fieldType.getMapValueType());
  }
  if (typeName == TypeName.ROW) {
    // The nested schema must already carry a UUID before a coder is generated for it.
    checkState(fieldType.getRowSchema().getUUID() != null);
    Coder<Row> nestedRowCoder = generate(fieldType.getRowSchema());
    return rowCoder(nestedRowCoder.getClass());
  }

  StackManipulation primitiveCoder = coderForPrimitiveType(typeName);
  if (!fieldType.getNullable()) {
    return primitiveCoder;
  }
  // Nullable primitive: feed the primitive coder into NullableCoder.of(...).
  return new Compound(
      primitiveCoder,
      MethodInvocation.invoke(
          NULLABLE_CODER.getDeclaredMethods().filter(ElementMatchers.named("of")).getOnly()));
}

/**
 * Returns the StackManipulation that creates a {@code ListCoder} whose element coder is the
 * coder for {@code fieldType}.
 */
private static StackManipulation listCoder(Schema.FieldType fieldType) {
  StackManipulation elementCoder = getCoder(fieldType);
  // ListCoder.of(elementCoder)
  StackManipulation invokeListCoderOf =
      MethodInvocation.invoke(
          LIST_CODER_TYPE.getDeclaredMethods().filter(ElementMatchers.named("of")).getOnly());
  return new Compound(elementCoder, invokeListCoderOf);
}

/**
 * Returns the StackManipulation that creates an {@code IterableCoder} whose element coder is the
 * coder for {@code fieldType}.
 */
private static StackManipulation iterableCoder(Schema.FieldType fieldType) {
  StackManipulation elementCoder = getCoder(fieldType);
  // IterableCoder.of(elementCoder)
  StackManipulation invokeIterableCoderOf =
      MethodInvocation.invoke(
          ITERABLE_CODER_TYPE.getDeclaredMethods().filter(ElementMatchers.named("of")).getOnly());
  return new Compound(elementCoder, invokeIterableCoderOf);
}

/**
 * Looks up the StackManipulation registered in {@code CODER_MAP} for the given type name.
 *
 * @param typeName the schema type name to look up
 * @return the registered StackManipulation, or {@code null} if the map has no entry for it
 */
static StackManipulation coderForPrimitiveType(Schema.TypeName typeName) {
  StackManipulation primitiveCoderFactory = CODER_MAP.get(typeName);
  return primitiveCoderFactory;
}

/**
 * Returns the StackManipulation that creates a {@code MapCoder} from the coders for the given
 * key and value types.
 *
 * @param keyType field type of the map keys
 * @param valueType field type of the map values
 */
static StackManipulation mapCoder(Schema.FieldType keyType, Schema.FieldType valueType) {
  StackManipulation keyCoderFactory = getCoder(keyType);
  StackManipulation valueCoderFactory = getCoder(valueType);
  // MapCoder.of(keyCoder, valueCoder)
  StackManipulation invokeMapCoderOf =
      MethodInvocation.invoke(
          MAP_CODER_TYPE.getDeclaredMethods().filter(ElementMatchers.named("of")).getOnly());
  return new Compound(keyCoderFactory, valueCoderFactory, invokeMapCoderOf);
}

/**
 * Returns the StackManipulation that instantiates {@code coderClass} through its no-argument
 * constructor (the bytecode equivalent of {@code new coderClass()}).
 *
 * @param coderClass the coder class to instantiate
 */
static StackManipulation rowCoder(Class coderClass) {
  ForLoadedType coderType = new ForLoadedType(coderClass);
  // Allocate the object first; a duplicate of the reference is consumed by the constructor call.
  StackManipulation allocate = TypeCreation.of(coderType);
  StackManipulation invokeNoArgConstructor =
      MethodInvocation.invoke(
          coderType
              .getDeclaredMethods()
              .filter(ElementMatchers.isConstructor().and(ElementMatchers.takesArguments(0)))
              .getOnly());
  return new Compound(allocate, Duplication.SINGLE, invokeNoArgConstructor);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.beam.sdk.annotations.Experimental.Kind;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.schemas.Schema.TypeName;
import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType;
import org.apache.beam.sdk.schemas.logicaltypes.OneOfType;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.Row;
Expand Down Expand Up @@ -126,21 +127,25 @@ private <ValueT> ValueT fromValue(
valueType,
typeFactory);
} else {
if (type.getTypeName().isLogicalType()
&& OneOfType.IDENTIFIER.equals(type.getLogicalType().getIdentifier())) {
if (type.isLogicalType(OneOfType.IDENTIFIER)) {
OneOfType oneOfType = type.getLogicalType(OneOfType.class);
OneOfType.Value oneOfValue = oneOfType.toInputType((Row) value);
EnumerationType oneOfEnum = oneOfType.getCaseEnumType();
OneOfType.Value oneOfValue = (OneOfType.Value) value;
FieldValueTypeInformation oneOfFieldValueTypeInformation =
checkNotNull(
fieldValueTypeInformation.getOneOfTypes().get(oneOfValue.getCaseType().toString()));
fieldValueTypeInformation
.getOneOfTypes()
.get(oneOfEnum.toString(oneOfValue.getCaseType())));
Object fromValue =
fromValue(
oneOfValue.getFieldType(),
oneOfType.getFieldType(oneOfValue),
oneOfValue.getValue(),
oneOfFieldValueTypeInformation.getRawType(),
oneOfFieldValueTypeInformation,
typeFactory);
return (ValueT) oneOfType.createValue(oneOfValue.getCaseType(), fromValue);
} else if (type.getTypeName().isLogicalType()) {
return (ValueT) type.getLogicalType().toBaseType(value);
}
return value;
}
Expand Down
Loading