Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,9 @@ public interface FieldValueGetter<ObjectT, ValueT> extends Serializable {
/**
 * Reads a value from {@code object}.
 *
 * @param object the object to read from
 * @return the value obtained from {@code object}; may be {@code null}
 */
@Nullable
ValueT get(ObjectT object);

/**
 * Returns the value this getter reads from {@code object} without any conversion that a
 * wrapping getter may layer on top of {@link #get}.
 *
 * <p>The default implementation simply delegates to {@link #get}. Getters that wrap another
 * getter and transform its result inside {@code get} (such as the {@code Converter} getters in
 * {@code GetterBasedSchemaProvider}) override this method to forward to the wrapped getter's
 * {@code getRaw}, so the untransformed result stays reachable. Per the review discussion this
 * exists to preserve the existing semantics of {@code RowWithGetters.getValues()}.
 *
 * @param object the object to read from
 * @return the unconverted getter result; may be {@code null}
 */
default @Nullable Object getRaw(ObjectT object) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe getUntyped is a more appropriate name here?

Also could you clarify why we need this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks so much for having a look @TheNeuralBit 🙏

getRaw() was based on a conversation with @reuvenlax.

getValues() is maybe poorly named - might be better called getRawValues. What you're looking for is probably the getBaseValues() method.
getValues is mostly used in code that knows exactly what it's doing for optimization purposes. It goes along with the attachValues method, which is similarly tricky to use. It's there to enable 0-copy code, but not necessarily intended for general consumption.

RowWithGetters.getValues() returns the "raw" unmodified result of the getters:

public List<Object> getValues() {
  return getters.stream().map(g -> g.get(getterTarget)).collect(Collectors.toList());
}

As I am pushing down the transformation of the getter result into the getter itself, I needed a way to bypass that in order to maintain the current semantics of getValues(). Let me know if the name makes sense given that context.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that makes sense, thanks. Could you add some of this context in a comment there?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TheNeuralBit I've opened a new PR to add the missing comment, sorry for the delay.
#21982

return get(object);
}

/**
 * Returns the name reported by this getter.
 *
 * <p>NOTE(review): presumably the name of the schema field this getter reads - confirm.
 */
String name();
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,27 @@
*/
package org.apache.beam.sdk.schemas;

import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.annotations.Experimental.Kind;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.schemas.Schema.LogicalType;
import org.apache.beam.sdk.schemas.Schema.TypeName;
import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType;
import org.apache.beam.sdk.schemas.logicaltypes.OneOfType;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Collections2;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
Expand Down Expand Up @@ -53,10 +67,10 @@ private class ToRowWithValueGetters<T> implements SerializableFunction<T, Row> {
/**
 * Creates a to-row function bound to {@code schema}.
 *
 * <p>The getter factory is obtained from {@code RowValueGettersFactory.of(...)}, which wraps
 * the provider's {@code fieldValueGetters} and hands back a caching factory.
 *
 * @param schema the schema stored on this function
 */
public ToRowWithValueGetters(Schema schema) {
  this.schema = schema;
  // Since we know that this factory is always called from inside the lambda with the same
  // schema, return a caching factory that caches the first value seen for each class. This
  // prevents having to lookup the getter list each time createGetters is called.
  this.getterFactory =
      RowValueGettersFactory.of(GetterBasedSchemaProvider.this::fieldValueGetters);
}

@Override
Expand Down Expand Up @@ -115,4 +129,245 @@ public int hashCode() {
// Two providers are equal exactly when they are instances of the same concrete class.
public boolean equals(@Nullable Object obj) {
  if (obj == null) {
    return false;
  }
  return obj.getClass() == this.getClass();
}

private static class RowValueGettersFactory implements Factory<List<FieldValueGetter>> {
private final Factory<List<FieldValueGetter>> gettersFactory;
private final Factory<List<FieldValueGetter>> cachingGettersFactory;

/**
 * Wraps {@code gettersFactory} in a new {@code RowValueGettersFactory} and returns the caching
 * factory that fronts it.
 *
 * @param gettersFactory factory producing the base getters to be wrapped
 * @return the caching factory owned by the newly created instance
 */
static Factory<List<FieldValueGetter>> of(Factory<List<FieldValueGetter>> gettersFactory) {
  RowValueGettersFactory rowFactory = new RowValueGettersFactory(gettersFactory);
  return rowFactory.cachingGettersFactory;
}

/**
 * Stores the base getter factory and creates a {@code CachingFactory} that delegates to this
 * instance.
 *
 * @param gettersFactory factory producing the unwrapped getters
 */
RowValueGettersFactory(Factory<List<FieldValueGetter>> gettersFactory) {
this.gettersFactory = gettersFactory;
// The caching factory wraps `this`, so getters created through it are the wrapped getters
// returned by create(...).
this.cachingGettersFactory = new CachingFactory<>(this);
}

/**
 * Builds the getters for {@code clazz} by asking the base factory for its getters and wrapping
 * each one according to the type of the schema field at the same position.
 *
 * @param clazz the class the getters are created for
 * @param schema the schema whose field types drive the wrapping
 * @return a new list holding one (possibly wrapped) getter per base getter, in the same order
 */
@Override
public List<FieldValueGetter> create(Class<?> clazz, Schema schema) {
  List<FieldValueGetter> baseGetters = gettersFactory.create(clazz, schema);
  List<FieldValueGetter> wrappedGetters = new ArrayList<>(baseGetters.size());
  int fieldIndex = 0;
  for (FieldValueGetter baseGetter : baseGetters) {
    // Getter i is paired with schema field i.
    FieldType fieldType = schema.getField(fieldIndex).getType();
    wrappedGetters.add(rowValueGetter(baseGetter, fieldType));
    fieldIndex++;
  }
  return wrappedGetters;
}

/**
 * Tells whether a value of the given type has to be wrapped by a converting getter.
 *
 * <p>Rows and logical types always need conversion. Arrays and iterables need it when their
 * element type does; maps need it when their key type or value type does. Everything else does
 * not.
 *
 * @param type the field type to inspect
 * @return {@code true} if the type, or a type nested inside it, requires conversion
 */
static boolean needsConversion(FieldType type) {
  TypeName kind = type.getTypeName();
  if (kind.equals(TypeName.ROW) || kind.isLogicalType()) {
    return true;
  }
  if (kind.equals(TypeName.ARRAY) || kind.equals(TypeName.ITERABLE)) {
    return needsConversion(type.getCollectionElementType());
  }
  if (kind.equals(TypeName.MAP)) {
    return needsConversion(type.getMapKeyType()) || needsConversion(type.getMapValueType());
  }
  return false;
}

/**
 * Wraps {@code base} in the converting getter that matches {@code type}.
 *
 * <p>When {@link #needsConversion} reports that no conversion is required, {@code base} is
 * returned untouched. Otherwise the wrapper is chosen by type: rows, arrays (eager when the
 * elements are rows, lazy otherwise), iterables, maps, one-of logical types, and finally any
 * other logical type.
 *
 * @param base the getter whose result is to be converted
 * @param type the schema type of the value produced by {@code base}
 * @return {@code base} itself or a getter wrapping it
 */
FieldValueGetter rowValueGetter(FieldValueGetter base, FieldType type) {
  if (!needsConversion(type)) {
    return base;
  }

  TypeName kind = type.getTypeName();
  if (kind.equals(TypeName.ROW)) {
    return new GetRow(base, type.getRowSchema(), cachingGettersFactory);
  }
  if (kind.equals(TypeName.ARRAY)) {
    FieldType elementType = type.getCollectionElementType();
    FieldValueGetter elementConverter = converter(elementType);
    if (elementType.getTypeName().equals(TypeName.ROW)) {
      return new GetEagerCollection(base, elementConverter);
    }
    return new GetCollection(base, elementConverter);
  }
  if (kind.equals(TypeName.ITERABLE)) {
    FieldValueGetter elementConverter = converter(type.getCollectionElementType());
    return new GetIterable(base, elementConverter);
  }
  if (kind.equals(TypeName.MAP)) {
    FieldValueGetter keyConverter = converter(type.getMapKeyType());
    FieldValueGetter valueConverter = converter(type.getMapValueType());
    return new GetMap(base, keyConverter, valueConverter);
  }
  // One-of must be tested before the generic logical-type case below.
  if (type.isLogicalType(OneOfType.IDENTIFIER)) {
    OneOfType oneOfType = type.getLogicalType(OneOfType.class);
    Schema caseSchema = oneOfType.getOneOfSchema();
    Map<String, Integer> caseIds = oneOfType.getCaseEnumType().getValuesMap();

    // One converter per case, keyed by the case's enum value.
    Map<Integer, FieldValueGetter> caseConverters =
        Maps.newHashMapWithExpectedSize(caseIds.size());
    caseIds.forEach(
        (caseName, caseId) -> {
          FieldType caseFieldType = caseSchema.getField(caseName).getType();
          caseConverters.put(caseId, converter(caseFieldType));
        });

    return new GetOneOf(base, caseConverters, oneOfType);
  }
  if (kind.isLogicalType()) {
    return new GetLogicalInputType(base, type.getLogicalType());
  }
  return base;
}

/**
 * Returns a getter that converts a value of the given type that is already in hand, by
 * wrapping the {@code IDENTITY} getter with {@link #rowValueGetter}.
 *
 * @param type the schema type of the values to convert
 * @return {@code IDENTITY} when no conversion is needed, otherwise a converting getter
 */
FieldValueGetter converter(FieldType type) {
return rowValueGetter(IDENTITY, type);
}

/**
 * Converter that turns the wrapped getter's result into a {@link Row} backed by field value
 * getters for the given schema.
 */
static class GetRow extends Converter<Object> {
  final Schema schema;
  final Factory<List<FieldValueGetter>> factory;

  /**
   * @param base getter producing the object to expose as a row
   * @param rowSchema schema of the resulting row
   * @param gettersFactory factory handed to the row to obtain its getters
   */
  GetRow(
      FieldValueGetter base, Schema rowSchema, Factory<List<FieldValueGetter>> gettersFactory) {
    super(base);
    this.schema = rowSchema;
    this.factory = gettersFactory;
  }

  /** Builds a row over {@code target} using this converter's schema and getter factory. */
  @Override
  Object convert(Object target) {
    return Row.withSchema(this.schema).withFieldValueGetters(this.factory, target);
  }
}

/**
 * Converter that copies a collection into a new {@link ArrayList}, converting every element up
 * front.
 */
static class GetEagerCollection extends Converter<Collection> {
  final FieldValueGetter converter;

  /**
   * @param base getter producing the source collection
   * @param elementConverter getter applied to each element
   */
  GetEagerCollection(FieldValueGetter base, FieldValueGetter elementConverter) {
    super(base);
    this.converter = elementConverter;
  }

  /** Returns a new list with the converted elements in the collection's iteration order. */
  @Override
  Object convert(Collection collection) {
    List converted = new ArrayList(collection.size());
    Iterator elements = collection.iterator();
    while (elements.hasNext()) {
      converted.add(converter.get(elements.next()));
    }
    return converted;
  }
}

/** Converter that exposes a collection through a transforming view of its elements. */
static class GetCollection extends Converter<Collection> {
  final FieldValueGetter converter;

  /**
   * @param base getter producing the source collection
   * @param elementConverter getter applied to each element by the view
   */
  GetCollection(FieldValueGetter base, FieldValueGetter elementConverter) {
    super(base);
    this.converter = elementConverter;
  }

  /**
   * Returns a transforming view over {@code collection}: a list view when the input is a list,
   * a plain collection view otherwise.
   */
  @Override
  Object convert(Collection collection) {
    if (!(collection instanceof List)) {
      return Collections2.transform(collection, converter::get);
    }
    // For performance reasons if the input is a list, make sure that we produce a list.
    // Otherwise, Row forwarding is forced to physically copy the collection into a new List
    // object.
    return Lists.transform((List) collection, converter::get);
  }
}

/** Converter that exposes an iterable through a transforming view of its elements. */
static class GetIterable extends Converter<Iterable> {
  final FieldValueGetter converter;

  /**
   * @param base getter producing the source iterable
   * @param elementConverter getter applied to each element by the view
   */
  GetIterable(FieldValueGetter base, FieldValueGetter elementConverter) {
    super(base);
    this.converter = elementConverter;
  }

  /** Returns {@code Iterables.transform} applied to {@code elements} with the converter. */
  @Override
  Object convert(Iterable elements) {
    return Iterables.transform(elements, this.converter::get);
  }
}

/** Converter that copies a map into a new hash map, converting every key and value. */
static class GetMap extends Converter<Map<?, ?>> {
  final FieldValueGetter keyConverter;
  final FieldValueGetter valueConverter;

  /**
   * @param base getter producing the source map
   * @param keyConverter getter applied to each key
   * @param valueConverter getter applied to each value
   */
  GetMap(FieldValueGetter base, FieldValueGetter keyConverter, FieldValueGetter valueConverter) {
    super(base);
    this.keyConverter = keyConverter;
    this.valueConverter = valueConverter;
  }

  /** Returns a new map holding the converted key and converted value of every entry. */
  @Override
  Object convert(Map<?, ?> source) {
    Map converted = Maps.newHashMapWithExpectedSize(source.size());
    for (Map.Entry<?, ?> sourceEntry : source.entrySet()) {
      Object convertedKey = keyConverter.get(sourceEntry.getKey());
      Object convertedValue = valueConverter.get(sourceEntry.getValue());
      converted.put(convertedKey, convertedValue);
    }
    return converted;
  }
}

/** Converter that passes the wrapped getter's result through a logical type's input mapping. */
static class GetLogicalInputType extends Converter<Object> {
  final LogicalType logicalType;

  /**
   * @param base getter producing the value to convert
   * @param logicalType logical type whose {@code toInputType} is applied
   */
  GetLogicalInputType(FieldValueGetter base, LogicalType logicalType) {
    super(base);
    this.logicalType = logicalType;
  }

  /** Applies {@code logicalType.toInputType} to {@code baseValue}. */
  @Override
  Object convert(Object baseValue) {
    // Getters are assumed to return the base type.
    return this.logicalType.toInputType(baseValue);
  }
}

/**
 * Converter for one-of values: converts the payload with the converter registered for the
 * value's case and rebuilds the one-of value around the converted payload.
 */
static class GetOneOf extends Converter<OneOfType.Value> {
  final OneOfType oneOfType;
  final Map<Integer, FieldValueGetter> converters;

  /**
   * @param base getter producing the one-of value
   * @param converters payload converters keyed by the case's enum value
   * @param oneOfType one-of type used to create the resulting value
   */
  GetOneOf(
      FieldValueGetter base, Map<Integer, FieldValueGetter> converters, OneOfType oneOfType) {
    super(base);
    this.converters = converters;
    this.oneOfType = oneOfType;
  }

  /**
   * Converts the payload of {@code value} with the converter for its case.
   *
   * @throws IllegalStateException if no converter is registered for the case
   */
  @Override
  Object convert(OneOfType.Value value) {
    EnumerationType.Value selectedCase = value.getCaseType();
    FieldValueGetter caseConverter = converters.get(selectedCase.getValue());
    checkState(caseConverter != null, "Missing OneOf converter for case %s.", selectedCase);
    Object convertedPayload = caseConverter.get(value.getValue());
    return oneOfType.createValue(selectedCase, convertedPayload);
  }
}

/**
 * Base class for getters that wrap another getter and convert its non-null results.
 *
 * <p>{@link #get} applies {@link #convert} to the wrapped getter's result, while
 * {@link #getRaw} and {@link #name} forward straight to the wrapped getter.
 *
 * @param <T> the type the wrapped getter's result is treated as when converting
 */
abstract static class Converter<T> implements FieldValueGetter {
  final FieldValueGetter getter;

  /** @param getter the getter whose result is converted */
  public Converter(FieldValueGetter getter) {
    this.getter = getter;
  }

  /** Converts a non-null value produced by the wrapped getter. */
  abstract Object convert(T value);

  /** Returns the converted result of the wrapped getter, or {@code null} if it returned null. */
  @Override
  public @Nullable Object get(Object object) {
    Object unconverted = getter.get(object);
    // Null results bypass conversion entirely.
    return unconverted == null ? null : convert((T) unconverted);
  }

  /** Forwards to the wrapped getter's {@code getRaw}, skipping this converter's conversion. */
  @Override
  public @Nullable Object getRaw(Object object) {
    return this.getter.getRaw(object);
  }

  /** Returns the wrapped getter's name. */
  @Override
  public String name() {
    return this.getter.name();
  }
}

/**
 * Getter that returns its argument unchanged. {@code converter(...)} uses it as the base getter
 * when building converters for values that are already in hand (collection elements, map keys
 * and values, one-of payloads).
 */
private static final FieldValueGetter IDENTITY =
    new FieldValueGetter() {
      @Override
      public @Nullable Object get(Object object) {
        return object;
      }

      @Override
      public String name() {
        // name() is declared to return a non-null String and Converter.name() forwards this
        // value, so report a fixed placeholder instead of null.
        return "IDENTITY";
      }
    };
}
}
Loading