diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueGetter.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueGetter.java index af48879c8b81..dd2279d65ecc 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueGetter.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueGetter.java @@ -33,5 +33,9 @@ public interface FieldValueGetter extends Serializable { @Nullable ValueT get(ObjectT object); + default @Nullable Object getRaw(ObjectT object) { + return get(object); + } + String name(); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java index b58f24add76e..10986a0de179 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java @@ -17,13 +17,27 @@ */ package org.apache.beam.sdk.schemas; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; + +import java.util.ArrayList; +import java.util.Collection; import java.util.List; +import java.util.Map; import java.util.Objects; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.Schema.LogicalType; +import org.apache.beam.sdk.schemas.Schema.TypeName; +import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType; +import org.apache.beam.sdk.schemas.logicaltypes.OneOfType; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Collections2; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; import org.checkerframework.checker.nullness.qual.Nullable; /** @@ -53,10 +67,10 @@ private class ToRowWithValueGetters implements SerializableFunction { public ToRowWithValueGetters(Schema schema) { this.schema = schema; // Since we know that this factory is always called from inside the lambda with the same - // schema, - // return a caching factory that caches the first value seen for each class. This prevents - // having to lookup the getter list each time createGetters is called. - this.getterFactory = new CachingFactory<>(GetterBasedSchemaProvider.this::fieldValueGetters); + // schema, return a caching factory that caches the first value seen for each class. This + // prevents having to lookup the getter list each time createGetters is called. + this.getterFactory = + RowValueGettersFactory.of(GetterBasedSchemaProvider.this::fieldValueGetters); } @Override @@ -115,4 +129,245 @@ public int hashCode() { public boolean equals(@Nullable Object obj) { return obj != null && this.getClass() == obj.getClass(); } + + private static class RowValueGettersFactory implements Factory> { + private final Factory> gettersFactory; + private final Factory> cachingGettersFactory; + + static Factory> of(Factory> gettersFactory) { + return new RowValueGettersFactory(gettersFactory).cachingGettersFactory; + } + + RowValueGettersFactory(Factory> gettersFactory) { + this.gettersFactory = gettersFactory; + this.cachingGettersFactory = new CachingFactory<>(this); + } + + @Override + public List create(Class clazz, Schema schema) { + List getters = gettersFactory.create(clazz, schema); + List rowGetters = new ArrayList<>(getters.size()); + for (int i = 0; i < getters.size(); i++) { + rowGetters.add(rowValueGetter(getters.get(i), schema.getField(i).getType())); + } + return rowGetters; + } + + static boolean needsConversion(FieldType type) { + TypeName typeName = type.getTypeName(); + return typeName.equals(TypeName.ROW) + || typeName.isLogicalType() + || ((typeName.equals(TypeName.ARRAY) || typeName.equals(TypeName.ITERABLE)) + && needsConversion(type.getCollectionElementType())) + || (typeName.equals(TypeName.MAP) + && (needsConversion(type.getMapKeyType()) + || needsConversion(type.getMapValueType()))); + } + + FieldValueGetter rowValueGetter(FieldValueGetter base, FieldType type) { + TypeName typeName = type.getTypeName(); + if (!needsConversion(type)) { + return base; + } + if (typeName.equals(TypeName.ROW)) { + return new GetRow(base, type.getRowSchema(), cachingGettersFactory); + } else if (typeName.equals(TypeName.ARRAY)) { + FieldType elementType = type.getCollectionElementType(); + return elementType.getTypeName().equals(TypeName.ROW) + ? new GetEagerCollection(base, converter(elementType)) + : new GetCollection(base, converter(elementType)); + } else if (typeName.equals(TypeName.ITERABLE)) { + return new GetIterable(base, converter(type.getCollectionElementType())); + } else if (typeName.equals(TypeName.MAP)) { + return new GetMap(base, converter(type.getMapKeyType()), converter(type.getMapValueType())); + } else if (type.isLogicalType(OneOfType.IDENTIFIER)) { + OneOfType oneOfType = type.getLogicalType(OneOfType.class); + Schema oneOfSchema = oneOfType.getOneOfSchema(); + Map values = oneOfType.getCaseEnumType().getValuesMap(); + + Map converters = Maps.newHashMapWithExpectedSize(values.size()); + for (Map.Entry kv : values.entrySet()) { + FieldType fieldType = oneOfSchema.getField(kv.getKey()).getType(); + FieldValueGetter converter = converter(fieldType); + converters.put(kv.getValue(), converter); + } + + return new GetOneOf(base, converters, oneOfType); + } else if (typeName.isLogicalType()) { + return new GetLogicalInputType(base, type.getLogicalType()); + } + return base; + } + + FieldValueGetter converter(FieldType type) { + return rowValueGetter(IDENTITY, type); + } + + static class GetRow extends Converter { + final Schema schema; + final Factory> factory; + + GetRow(FieldValueGetter getter, Schema schema, Factory> factory) { + super(getter); + this.schema = schema; + this.factory = factory; + } + + @Override + Object convert(Object value) { + return Row.withSchema(schema).withFieldValueGetters(factory, value); + } + } + + static class GetEagerCollection extends Converter { + final FieldValueGetter converter; + + GetEagerCollection(FieldValueGetter getter, FieldValueGetter converter) { + super(getter); + this.converter = converter; + } + + @Override + Object convert(Collection collection) { + List newList = new ArrayList(collection.size()); + for (Object obj : collection) { + newList.add(converter.get(obj)); + } + return newList; + } + } + + static class GetCollection extends Converter { + final FieldValueGetter converter; + + GetCollection(FieldValueGetter getter, FieldValueGetter converter) { + super(getter); + this.converter = converter; + } + + @Override + Object convert(Collection collection) { + if (collection instanceof List) { + // For performance reasons if the input is a list, make sure that we produce a list. + // Otherwise, Row forwarding is forced to physically copy the collection into a new List + // object. + return Lists.transform((List) collection, converter::get); + } else { + return Collections2.transform(collection, converter::get); + } + } + } + + static class GetIterable extends Converter { + final FieldValueGetter converter; + + GetIterable(FieldValueGetter getter, FieldValueGetter converter) { + super(getter); + this.converter = converter; + } + + @Override + Object convert(Iterable value) { + return Iterables.transform(value, converter::get); + } + } + + static class GetMap extends Converter> { + final FieldValueGetter keyConverter; + final FieldValueGetter valueConverter; + + GetMap( + FieldValueGetter getter, FieldValueGetter keyConverter, FieldValueGetter valueConverter) { + super(getter); + this.keyConverter = keyConverter; + this.valueConverter = valueConverter; + } + + @Override + Object convert(Map value) { + Map returnMap = Maps.newHashMapWithExpectedSize(value.size()); + for (Map.Entry entry : value.entrySet()) { + returnMap.put(keyConverter.get(entry.getKey()), valueConverter.get(entry.getValue())); + } + return returnMap; + } + } + + static class GetLogicalInputType extends Converter { + final LogicalType logicalType; + + GetLogicalInputType(FieldValueGetter getter, LogicalType logicalType) { + super(getter); + this.logicalType = logicalType; + } + + @Override + Object convert(Object value) { + // Getters are assumed to return the base type. + return logicalType.toInputType(value); + } + } + + static class GetOneOf extends Converter { + final OneOfType oneOfType; + final Map converters; + + GetOneOf( + FieldValueGetter getter, Map converters, OneOfType oneOfType) { + super(getter); + this.converters = converters; + this.oneOfType = oneOfType; + } + + @Override + Object convert(OneOfType.Value value) { + EnumerationType.Value caseType = value.getCaseType(); + FieldValueGetter converter = converters.get(caseType.getValue()); + checkState(converter != null, "Missing OneOf converter for case %s.", caseType); + return oneOfType.createValue(caseType, converter.get(value.getValue())); + } + } + + abstract static class Converter implements FieldValueGetter { + final FieldValueGetter getter; + + public Converter(FieldValueGetter getter) { + this.getter = getter; + } + + abstract Object convert(T value); + + @Override + public @Nullable Object get(Object object) { + T value = (T) getter.get(object); + if (value == null) { + return null; + } + return convert(value); + } + + @Override + public @Nullable Object getRaw(Object object) { + return getter.getRaw(object); + } + + @Override + public String name() { + return getter.name(); + } + } + + private static final FieldValueGetter IDENTITY = + new FieldValueGetter() { + @Override + public @Nullable Object get(Object object) { + return object; + } + + @Override + public String name() { + return null; + } + }; + } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java index d105c744dd91..96d1e04ddf0d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java @@ -17,24 +17,18 @@ */ package org.apache.beam.sdk.values; -import java.util.Collection; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.stream.Collectors; +import java.util.TreeMap; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; import org.apache.beam.sdk.schemas.Factory; import org.apache.beam.sdk.schemas.FieldValueGetter; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.Field; -import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.Schema.TypeName; -import org.apache.beam.sdk.schemas.logicaltypes.OneOfType; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Collections2; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; import org.checkerframework.checker.nullness.qual.Nullable; /** @@ -50,97 +44,43 @@ "rawtypes" }) public class RowWithGetters extends Row { - private final Factory> fieldValueGetterFactory; private final Object getterTarget; private final List getters; - - private final Map cachedCollections = Maps.newHashMap(); - private final Map cachedIterables = Maps.newHashMap(); - private final Map cachedMaps = Maps.newHashMap(); + private Map cache; RowWithGetters( Schema schema, Factory> getterFactory, Object getterTarget) { super(schema); - this.fieldValueGetterFactory = getterFactory; this.getterTarget = getterTarget; - this.getters = fieldValueGetterFactory.create(getterTarget.getClass(), schema); + this.getters = getterFactory.create(getterTarget.getClass(), schema); } @Override @SuppressWarnings({"TypeParameterUnusedInFormals", "unchecked"}) public @Nullable T getValue(int fieldIdx) { Field field = getSchema().getField(fieldIdx); - FieldType type = field.getType(); - Object fieldValue = getters.get(fieldIdx).get(getterTarget); - if (fieldValue == null && !field.getType().getNullable()) { - throw new RuntimeException("Null value set on non-nullable field " + field); - } - return fieldValue != null ? getValue(type, fieldValue, fieldIdx) : null; - } + boolean cacheField = cacheFieldType(field); - private Collection getCollectionValue(FieldType elementType, Object fieldValue) { - Collection collection = (Collection) fieldValue; - if (collection instanceof List) { - // For performance reasons if the input is a list, make sure that we produce a list. Otherwise - // Row forwarding - // is forced to physically copy the collection into a new List object. - return Lists.transform((List) collection, v -> getValue(elementType, v, null)); - } else { - return Collections2.transform(collection, v -> getValue(elementType, v, null)); + if (cacheField && cache == null) { + cache = new TreeMap<>(); } - } - private Iterable getIterableValue(FieldType elementType, Object fieldValue) { - Iterable iterable = (Iterable) fieldValue; - // Wrap the iterable to avoid having to materialize the entire collection. - return Iterables.transform(iterable, v -> getValue(elementType, v, null)); - } + Object fieldValue = + cacheField + ? cache.computeIfAbsent(fieldIdx, idx -> getters.get(idx).get(getterTarget)) + : getters.get(fieldIdx).get(getterTarget); - private Map getMapValue(FieldType keyType, FieldType valueType, Map fieldValue) { - Map returnMap = Maps.newHashMap(); - for (Map.Entry entry : fieldValue.entrySet()) { - returnMap.put( - getValue(keyType, entry.getKey(), null), getValue(valueType, entry.getValue(), null)); + if (fieldValue == null && !field.getType().getNullable()) { + throw new RuntimeException("Null value set on non-nullable field " + field); } - return returnMap; + return (T) fieldValue; } - @SuppressWarnings({"TypeParameterUnusedInFormals", "unchecked"}) - private T getValue(FieldType type, Object fieldValue, @Nullable Integer cacheKey) { - if (type.getTypeName().equals(TypeName.ROW)) { - return (T) new RowWithGetters(type.getRowSchema(), fieldValueGetterFactory, fieldValue); - } else if (type.getTypeName().equals(TypeName.ARRAY)) { - return cacheKey != null - ? (T) - cachedCollections.computeIfAbsent( - cacheKey, i -> getCollectionValue(type.getCollectionElementType(), fieldValue)) - : (T) getCollectionValue(type.getCollectionElementType(), fieldValue); - } else if (type.getTypeName().equals(TypeName.ITERABLE)) { - return cacheKey != null - ? (T) - cachedIterables.computeIfAbsent( - cacheKey, i -> getIterableValue(type.getCollectionElementType(), fieldValue)) - : (T) getIterableValue(type.getCollectionElementType(), fieldValue); - } else if (type.getTypeName().equals(TypeName.MAP)) { - Map map = (Map) fieldValue; - return cacheKey != null - ? (T) - cachedMaps.computeIfAbsent( - cacheKey, i -> getMapValue(type.getMapKeyType(), type.getMapValueType(), map)) - : (T) getMapValue(type.getMapKeyType(), type.getMapValueType(), map); - } else { - if (type.isLogicalType(OneOfType.IDENTIFIER)) { - OneOfType oneOfType = type.getLogicalType(OneOfType.class); - OneOfType.Value oneOfValue = (OneOfType.Value) fieldValue; - Object convertedOneOfField = - getValue(oneOfType.getFieldType(oneOfValue), oneOfValue.getValue(), null); - return (T) oneOfType.createValue(oneOfValue.getCaseType(), convertedOneOfField); - } else if (type.getTypeName().isLogicalType()) { - // Getters are assumed to return the base type. - return (T) type.getLogicalType().toInputType(fieldValue); - } - return (T) fieldValue; - } + private boolean cacheFieldType(Field field) { + TypeName typeName = field.getType().getTypeName(); + return typeName.equals(TypeName.MAP) + || typeName.equals(TypeName.ARRAY) + || typeName.equals(TypeName.ITERABLE); } @Override @@ -150,7 +90,11 @@ public int getFieldCount() { @Override public List getValues() { - return getters.stream().map(g -> g.get(getterTarget)).collect(Collectors.toList()); + List rawValues = new ArrayList<>(getters.size()); + for (FieldValueGetter getter : getters) { + rawValues.add(getter.getRaw(getterTarget)); + } + return rawValues; } public List getGetters() {