diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java index 48c708c75fab..b8af5acfc855 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java @@ -499,6 +499,7 @@ public static class Builder { @Nullable private Factory> fieldValueGetterFactory; @Nullable private Object getterTarget; private Schema schema; + private boolean collectionHandledByGetter = false; Builder(Schema schema) { this.schema = schema; @@ -554,6 +555,12 @@ public Builder withFieldValueGetters( return this; } + /** The FieldValueGetters will handle the conversion for Arrays, Maps and Rows. */ + public Builder withFieldValueGettersHandleCollections(boolean collectionHandledByGetter) { + this.collectionHandledByGetter = collectionHandledByGetter; + return this; + } + private List verify(Schema schema, List values) { List verifiedValues = Lists.newArrayListWithCapacity(values.size()); if (schema.getFieldCount() != values.size()) { @@ -754,7 +761,11 @@ public Row build() { return new RowWithStorage(schema, storageValues); } else if (fieldValueGetterFactory != null) { checkState(getterTarget != null, "getters require withGetterTarget."); - return new RowWithGetters(schema, fieldValueGetterFactory, getterTarget); + if (collectionHandledByGetter) { + return new RowWithGetters(schema, fieldValueGetterFactory, getterTarget); + } else { + return new RowWithGettersCachedCollections(schema, fieldValueGetterFactory, getterTarget); + } } else { return new RowWithStorage(schema, Collections.emptyList()); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java index e50f818c1f67..8a5c11c58866 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java @@ -18,7 +18,6 @@ package org.apache.beam.sdk.values; import java.util.List; -import java.util.Map; import java.util.Objects; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -27,9 +26,6 @@ import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; -import org.apache.beam.sdk.schemas.Schema.TypeName; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; /** * A Concrete subclass of {@link Row} that delegates to a set of provided {@link FieldValueGetter}s. @@ -39,12 +35,9 @@ * the appropriate fields from the POJO. */ public class RowWithGetters extends Row { - private final Factory> fieldValueGetterFactory; - private final Object getterTarget; - private final List getters; - - private final Map cachedLists = Maps.newHashMap(); - private final Map cachedMaps = Maps.newHashMap(); + final Factory> fieldValueGetterFactory; + final Object getterTarget; + final List getters; RowWithGetters( Schema schema, Factory> getterFactory, Object getterTarget) { @@ -67,44 +60,9 @@ public T getValue(int fieldIdx) { return fieldValue != null ? getValue(type, fieldValue, fieldIdx) : null; } - private List getListValue(FieldType elementType, Object fieldValue) { - Iterable iterable = (Iterable) fieldValue; - List list = Lists.newArrayList(); - for (Object o : iterable) { - list.add(getValue(elementType, o, null)); - } - return list; - } - - private Map getMapValue(FieldType keyType, FieldType valueType, Map fieldValue) { - Map returnMap = Maps.newHashMap(); - for (Map.Entry entry : fieldValue.entrySet()) { - returnMap.put( - getValue(keyType, entry.getKey(), null), getValue(valueType, entry.getValue(), null)); - } - return returnMap; - } - @SuppressWarnings({"TypeParameterUnusedInFormals", "unchecked"}) - private T getValue(FieldType type, Object fieldValue, @Nullable Integer cacheKey) { - if (type.getTypeName().equals(TypeName.ROW)) { - return (T) new RowWithGetters(type.getRowSchema(), fieldValueGetterFactory, fieldValue); - } else if (type.getTypeName().equals(TypeName.ARRAY)) { - return cacheKey != null - ? (T) - cachedLists.computeIfAbsent( - cacheKey, i -> getListValue(type.getCollectionElementType(), fieldValue)) - : (T) getListValue(type.getCollectionElementType(), fieldValue); - } else if (type.getTypeName().equals(TypeName.MAP)) { - Map map = (Map) fieldValue; - return cacheKey != null - ? (T) - cachedMaps.computeIfAbsent( - cacheKey, i -> getMapValue(type.getMapKeyType(), type.getMapValueType(), map)) - : (T) getMapValue(type.getMapKeyType(), type.getMapValueType(), map); - } else { - return (T) fieldValue; - } + protected T getValue(FieldType type, Object fieldValue, @Nullable Integer cacheKey) { + return (T) fieldValue; } @Override diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGettersCachedCollections.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGettersCachedCollections.java new file mode 100644 index 000000000000..cbbcc7619857 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGettersCachedCollections.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.values; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import javax.annotation.Nullable; +import org.apache.beam.sdk.schemas.Factory; +import org.apache.beam.sdk.schemas.FieldValueGetter; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.Schema.TypeName; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; + +/** + * A Concrete subclass of {@link Row} that delegates to a set of provided {@link FieldValueGetter}s. + * This is a special version of {@link RowWithGetters} that cached the map and list collection. + * + *

This allows us to have {@link Row} objects for which the actual storage is in another object. + * For example, the user's type may be a POJO, in which case the provided getters will simple read + * the appropriate fields from the POJO. + */ +public class RowWithGettersCachedCollections extends RowWithGetters { + private final Map cachedLists = Maps.newHashMap(); + private final Map cachedMaps = Maps.newHashMap(); + + RowWithGettersCachedCollections( + Schema schema, Factory> getterFactory, Object getterTarget) { + super(schema, getterFactory, getterTarget); + } + + private List getListValue(FieldType elementType, Object fieldValue) { + Iterable iterable = (Iterable) fieldValue; + List list = Lists.newArrayList(); + for (Object o : iterable) { + list.add(getValue(elementType, o, null)); + } + return list; + } + + private Map getMapValue(FieldType keyType, FieldType valueType, Map fieldValue) { + Map returnMap = Maps.newHashMap(); + for (Map.Entry entry : fieldValue.entrySet()) { + returnMap.put( + getValue(keyType, entry.getKey(), null), getValue(valueType, entry.getValue(), null)); + } + return returnMap; + } + + @SuppressWarnings({"TypeParameterUnusedInFormals", "unchecked"}) + @Override + protected T getValue(FieldType type, Object fieldValue, @Nullable Integer cacheKey) { + if (type.getTypeName().equals(TypeName.ROW)) { + return (T) + new RowWithGettersCachedCollections( + type.getRowSchema(), fieldValueGetterFactory, fieldValue); + } else if (type.getTypeName().equals(TypeName.ARRAY)) { + return cacheKey != null + ? (T) + cachedLists.computeIfAbsent( + cacheKey, i -> getListValue(type.getCollectionElementType(), fieldValue)) + : (T) getListValue(type.getCollectionElementType(), fieldValue); + } else if (type.getTypeName().equals(TypeName.MAP)) { + Map map = (Map) fieldValue; + return cacheKey != null + ? (T) + cachedMaps.computeIfAbsent( + cacheKey, i -> getMapValue(type.getMapKeyType(), type.getMapValueType(), map)) + : (T) getMapValue(type.getMapKeyType(), type.getMapValueType(), map); + } else { + return (T) fieldValue; + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null) { + return false; + } + if (o instanceof RowWithGettersCachedCollections) { + RowWithGettersCachedCollections other = (RowWithGettersCachedCollections) o; + return Objects.equals(getSchema(), other.getSchema()) + && Objects.equals(getterTarget, other.getterTarget); + } else if (o instanceof Row) { + return super.equals(o); + } + return false; + } + + @Override + public int hashCode() { + return Objects.hash(getSchema(), getterTarget); + } +} diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoFieldOverlay.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoFieldOverlay.java new file mode 100644 index 000000000000..6cddc2756358 --- /dev/null +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoFieldOverlay.java @@ -0,0 +1,525 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.protobuf; + +import com.google.protobuf.ByteString; +import com.google.protobuf.Descriptors; +import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.DynamicMessage; +import com.google.protobuf.Message; +import com.google.protobuf.Timestamp; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.schemas.FieldValueGetter; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaCoder; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.joda.time.Instant; + +/** + * Protobuf ProtoFieldOverlay is the interface that each implementation needs to implement to handle + * a specific field types. + */ +@Experimental(Experimental.Kind.SCHEMAS) +public interface ProtoFieldOverlay extends FieldValueGetter { + + ValueT convertGetObject(FieldDescriptor fieldDescriptor, Object object); + + /** Convert the Row field and set it on the overlayed field of the message. */ + void set(Message.Builder object, ValueT value); + + Object convertSetObject(FieldDescriptor fieldDescriptor, Object value); + + /** Return the Beam Schema Field of this overlayed field. */ + Schema.Field getSchemaField(); + + abstract class ProtoFieldOverlayBase implements ProtoFieldOverlay { + + protected int number; + + private Schema.Field field; + + FieldDescriptor getFieldDescriptor(Message message) { + return message.getDescriptorForType().findFieldByNumber(number); + } + + FieldDescriptor getFieldDescriptor(Message.Builder message) { + return message.getDescriptorForType().findFieldByNumber(number); + } + + protected void setField(Schema.Field field) { + this.field = field; + } + + ProtoFieldOverlayBase(ProtoSchema protoSchema, FieldDescriptor fieldDescriptor) { + // this.fieldDescriptor = fieldDescriptor; + this.number = fieldDescriptor.getNumber(); + } + + @Override + public String name() { + return field.getName(); + } + + @Override + public Schema.Field getSchemaField() { + return field; + } + } + + /** Overlay for Protobuf primitive types. Primitive values are just passed through. */ + class PrimitiveOverlay extends ProtoFieldOverlayBase { + PrimitiveOverlay(ProtoSchema protoSchema, FieldDescriptor fieldDescriptor) { + // this.fieldDescriptor = fieldDescriptor; + super(protoSchema, fieldDescriptor); + setField( + Schema.Field.of( + fieldDescriptor.getName(), + ProtoSchema.convertType(fieldDescriptor.getType()) + .withMetadata(protoSchema.convertOptions(fieldDescriptor)))); + } + + @Override + public Object get(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + return convertGetObject(fieldDescriptor, message.getField(fieldDescriptor)); + } + + @Override + public Object convertGetObject(FieldDescriptor fieldDescriptor, Object object) { + return object; + } + + @Override + public void set(Message.Builder message, Object value) { + message.setField(getFieldDescriptor(message), value); + } + + @Override + public Object convertSetObject(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } + + /** + * Overlay for Bytes. Protobuf Bytes are natively represented as ByteStrings that requires special + * handling for byte[] of size 0. + */ + class BytesOverlay extends PrimitiveOverlay { + BytesOverlay(ProtoSchema protoSchema, FieldDescriptor fieldDescriptor) { + super(protoSchema, fieldDescriptor); + } + + @Override + public Object convertGetObject(FieldDescriptor fieldDescriptor, Object object) { + // return object; + return ((ByteString) object).toByteArray(); + } + + @Override + public void set(Message.Builder message, Object value) { + if (value != null && ((byte[]) value).length > 0) { + // Protobuf messages BYTES doesn't like empty bytes?! + FieldDescriptor fieldDescriptor = message.getDescriptorForType().findFieldByNumber(number); + message.setField(fieldDescriptor, convertSetObject(fieldDescriptor, value)); + } + } + + @Override + public Object convertSetObject(FieldDescriptor fieldDescriptor, Object value) { + if (value != null) { + return ByteString.copyFrom((byte[]) value); + } + return null; + } + } + + /** + * Overlay handler for the Well Known Type "Wrapper". These wrappers make it possible to have + * nullable primitives. + */ + class WrapperOverlay extends ProtoFieldOverlayBase { + private ProtoFieldOverlay value; + + WrapperOverlay(ProtoSchema protoSchema, FieldDescriptor fieldDescriptor) { + super(protoSchema, fieldDescriptor); + FieldDescriptor valueDescriptor = fieldDescriptor.getMessageType().findFieldByName("value"); + this.value = protoSchema.createFieldLayer(valueDescriptor, false); + setField( + Schema.Field.of( + fieldDescriptor.getName(), value.getSchemaField().getType().withNullable(true))); + } + + @Override + public ValueT get(Message message) { + if (message.hasField(getFieldDescriptor(message))) { + Message wrapper = (Message) message.getField(getFieldDescriptor(message)); + return (ValueT) value.get(wrapper); + } + return null; + } + + @Override + public ValueT convertGetObject(FieldDescriptor fieldDescriptor, Object object) { + return (ValueT) object; + } + + @Override + public void set(Message.Builder message, ValueT value) { + if (value != null) { + DynamicMessage.Builder builder = + DynamicMessage.newBuilder(getFieldDescriptor(message).getMessageType()); + this.value.set(builder, value); + message.setField(getFieldDescriptor(message), builder.build()); + } + } + + @Override + public Object convertSetObject(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } + + /** + * Overlay handler for the Well Known Type "Timestamp". This wrappers converts from a single Row + * DATETIME and a protobuf "Timestamp" messsage. + */ + class TimestampOverlay extends ProtoFieldOverlayBase { + TimestampOverlay(ProtoSchema protoSchema, FieldDescriptor fieldDescriptor) { + super(protoSchema, fieldDescriptor); + setField( + Schema.Field.of( + fieldDescriptor.getName(), + Schema.FieldType.DATETIME.withMetadata( + protoSchema.convertOptions(fieldDescriptor))) + .withNullable(true)); + } + + @Override + public Instant get(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + if (message.hasField(fieldDescriptor)) { + Message wrapper = (Message) message.getField(fieldDescriptor); + return convertGetObject(fieldDescriptor, wrapper); + } + return null; + } + + @Override + public Instant convertGetObject(FieldDescriptor fieldDescriptor, Object object) { + Message timestamp = (Message) object; + Descriptors.Descriptor timestampFieldDescriptor = timestamp.getDescriptorForType(); + return new Instant( + (Long) timestamp.getField(timestampFieldDescriptor.findFieldByName("seconds")) * 1000 + + (Integer) timestamp.getField(timestampFieldDescriptor.findFieldByName("nanos")) + / 1000000); + } + + @Override + public void set(Message.Builder message, Instant value) { + if (value != null) { + long totalMillis = value.getMillis(); + long seconds = totalMillis / 1000; + int ns = (int) (totalMillis % 1000 * 1000000); + Timestamp timestamp = Timestamp.newBuilder().setSeconds(seconds).setNanos(ns).build(); + message.setField(getFieldDescriptor(message), timestamp); + } + } + + @Override + public Object convertSetObject(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } + + /** This overlay converts a nested Message into a nested Row. */ + class MessageOverlay extends ProtoFieldOverlayBase { + private final SerializableFunction toRowFunction; + private final SerializableFunction fromRowFunction; + + MessageOverlay(ProtoSchema rootProtoSchema, FieldDescriptor fieldDescriptor) { + super(rootProtoSchema, fieldDescriptor); + + ProtoSchema protoSchema = + ProtoSchema.newBuilder(rootProtoSchema).forDescriptor(fieldDescriptor.getMessageType()); + SchemaCoder schemaCoder = protoSchema.getSchemaCoder(); + toRowFunction = schemaCoder.getToRowFunction(); + fromRowFunction = schemaCoder.getFromRowFunction(); + setField( + Schema.Field.of( + fieldDescriptor.getName(), + Schema.FieldType.row(protoSchema.getSchema()) + .withMetadata(protoSchema.convertOptions(fieldDescriptor)) + .withNullable(true))); + } + + @Override + public Object get(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + if (message.hasField(fieldDescriptor)) { + return convertGetObject(fieldDescriptor, message.getField(fieldDescriptor)); + } + return null; + } + + @Override + public Object convertGetObject(FieldDescriptor fieldDescriptor, Object object) { + return toRowFunction.apply(object); + } + + @Override + public void set(Message.Builder message, Object value) { + if (value != null) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + message.setField(fieldDescriptor, convertSetObject(fieldDescriptor, value)); + } + } + + @Override + public Object convertSetObject(FieldDescriptor fieldDescriptor, Object value) { + return fromRowFunction.apply(value); + } + } + + /** + * Proto has a well defined way of storing maps, by having a Message with two fields, named "key" + * and "value" in a repeatable field. This overlay translates between Row.map and the Protobuf + * map. + */ + class MapOverlay extends ProtoFieldOverlayBase { + private ProtoFieldOverlay key; + private ProtoFieldOverlay value; + + MapOverlay(ProtoSchema protoSchema, FieldDescriptor fieldDescriptor) { + super(protoSchema, fieldDescriptor); + key = + protoSchema.createFieldLayer( + fieldDescriptor.getMessageType().findFieldByName("key"), false); + value = + protoSchema.createFieldLayer( + fieldDescriptor.getMessageType().findFieldByName("value"), false); + setField( + Schema.Field.of( + fieldDescriptor.getName(), + Schema.FieldType.map( + key.getSchemaField().getType(), + value + .getSchemaField() + .getType() + .withMetadata(protoSchema.convertOptions(fieldDescriptor))) + .withNullable(true))); + } + + @Override + public Map get(Message message) { + List list = (List) message.getField(getFieldDescriptor(message)); + if (list.size() == 0) { + return null; + } + Map rowMap = new HashMap(); + list.forEach( + entry -> { + Message entryMessage = (Message) entry; + Descriptors.Descriptor entryDescriptor = entryMessage.getDescriptorForType(); + FieldDescriptor keyFieldDescriptor = entryDescriptor.findFieldByName("key"); + FieldDescriptor valueFieldDescriptor = entryDescriptor.findFieldByName("value"); + rowMap.put( + key.convertGetObject(keyFieldDescriptor, entryMessage.getField(keyFieldDescriptor)), + this.value.convertGetObject( + valueFieldDescriptor, entryMessage.getField(valueFieldDescriptor))); + }); + return rowMap; + } + + @Override + public Map convertGetObject(FieldDescriptor fieldDescriptor, Object object) { + throw new RuntimeException("?"); + } + + @Override + public void set(Message.Builder message, Map map) { + if (map != null) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + List messageMap = new ArrayList(); + map.forEach( + (k, v) -> { + DynamicMessage.Builder builder = + DynamicMessage.newBuilder(fieldDescriptor.getMessageType()); + FieldDescriptor keyFieldDescriptor = + fieldDescriptor.getMessageType().findFieldByName("key"); + builder.setField( + keyFieldDescriptor, this.key.convertSetObject(keyFieldDescriptor, k)); + FieldDescriptor valueFieldDescriptor = + fieldDescriptor.getMessageType().findFieldByName("value"); + builder.setField( + valueFieldDescriptor, value.convertSetObject(valueFieldDescriptor, v)); + messageMap.add(builder.build()); + }); + message.setField(fieldDescriptor, messageMap); + } + } + + @Override + public Object convertSetObject(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } + + /** + * This overlay handles repeatable fields. It handles the Array conversion, but delegates the + * conversion of the individual elements to an embedded overlay. + */ + class ArrayOverlay extends ProtoFieldOverlayBase { + private ProtoFieldOverlay element; + + ArrayOverlay(ProtoSchema protoSchema, FieldDescriptor fieldDescriptor) { + super(protoSchema, fieldDescriptor); + this.element = protoSchema.createFieldLayer(fieldDescriptor, false); + setField( + Schema.Field.of( + fieldDescriptor.getName(), + Schema.FieldType.array( + element + .getSchemaField() + .getType() + .withMetadata(protoSchema.convertOptions(fieldDescriptor))) + .withNullable(true))); + } + + @Override + public List get(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + List list = (List) message.getField(fieldDescriptor); + if (list.size() == 0) { + return null; + } + List arrayList = new ArrayList<>(); + list.forEach( + entry -> { + arrayList.add(element.convertGetObject(fieldDescriptor, entry)); + }); + return arrayList; + } + + @Override + public List convertGetObject(FieldDescriptor fieldDescriptor, Object object) { + throw new RuntimeException("?"); + } + + @Override + public void set(Message.Builder message, List list) { + if (list != null) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + List targetList = new ArrayList(); + list.forEach( + (e) -> { + targetList.add(element.convertSetObject(fieldDescriptor, e)); + }); + message.setField(fieldDescriptor, targetList); + } + } + + @Override + public Object convertSetObject(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } + + /** Enum overlay handles the conversion between a string and a ProtoBuf Enum. */ + class EnumOverlay extends ProtoFieldOverlayBase { + + EnumOverlay(ProtoSchema protoSchema, FieldDescriptor fieldDescriptor) { + super(protoSchema, fieldDescriptor); + setField( + Schema.Field.of( + fieldDescriptor.getName(), + Schema.FieldType.STRING.withMetadata(protoSchema.convertOptions(fieldDescriptor)))); + } + + @Override + public Object get(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + return convertGetObject(fieldDescriptor, message.getField(fieldDescriptor)); + } + + @Override + public Object convertGetObject(FieldDescriptor fieldDescriptor, Object in) { + return in.toString(); + } + + @Override + public void set(Message.Builder message, Object value) { + // builder.setField(fieldDescriptor, + // convertSetObject(row.getString(fieldDescriptor.getName()))); + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + message.setField(fieldDescriptor, convertSetObject(fieldDescriptor, value)); + } + + @Override + public Object convertSetObject(FieldDescriptor fieldDescriptor, Object value) { + Descriptors.EnumDescriptor enumType = fieldDescriptor.getEnumType(); + return enumType.findValueByName(value.toString()); + } + } + + /** + * This overlay handles nullable fields. If a primitive field needs to be nullable this overlay is + * wrapped around the original overlay. + */ + class NullableOverlay extends ProtoFieldOverlayBase { + + private ProtoFieldOverlay fieldOverlay; + + NullableOverlay( + ProtoSchema protoSchema, + FieldDescriptor fieldDescriptor, + ProtoFieldOverlay fieldOverlay) { + super(protoSchema, fieldDescriptor); + this.fieldOverlay = fieldOverlay; + setField(fieldOverlay.getSchemaField().withNullable(true)); + } + + @Override + public Object get(Message message) { + if (message.hasField(getFieldDescriptor(message))) { + return fieldOverlay.get(message); + } + return null; + } + + @Override + public Object convertGetObject(FieldDescriptor fieldDescriptor, Object object) { + throw new RuntimeException("Value conversion should never be allowed in nullable fields"); + } + + @Override + public void set(Message.Builder message, Object value) { + if (value != null) { + fieldOverlay.set(message, value); + } + } + + @Override + public Object convertSetObject(FieldDescriptor fieldDescriptor, Object value) { + throw new RuntimeException("Value conversion should never be allowed in nullable fields"); + } + } +} diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchema.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchema.java new file mode 100644 index 000000000000..351da3ff9f41 --- /dev/null +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchema.java @@ -0,0 +1,568 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.protobuf; + +import com.google.protobuf.DescriptorProtos; +import com.google.protobuf.Descriptors; +import com.google.protobuf.DynamicMessage; +import com.google.protobuf.Message; +import com.google.protobuf.UnknownFieldSet; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.math.BigInteger; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.Collectors; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.schemas.Factory; +import org.apache.beam.sdk.schemas.FieldValueGetter; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaCoder; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; + +/** + * ProtoSchema is a top level anchor point. It makes sure it can recreate the complete schema and + * overlay with just the Message raw type or if it's a DynamicMessage with the serialised + * Descriptor. + * + *

ProtoDomain is an integral part of a ProtoSchema, it it contains all the information needed to + * iterpret and reconstruct messages. + * + *

    + *
  • Protobuf oneOf fields are mapped to nullable fields and flattened into the parent row. + *
  • Protobuf primitives are mapped to it's non nullable counter part. + *
  • Protobuf maps are mapped to nullable maps, where empty maps are mapped to the null value. + *
  • Protobuf repeatables are mapped to nullable arrays, where empty arrays are mapped to the + * null value. + *
  • Protobuf enums are mapped to non-nullable string values. + *
+ * + *

Protobuf Well Know Types are handled by the Beam Schema system. Beam knows of the following + * Well Know Types: + * + *

    + *
  • google.protobuf.Timestamp maps to a nullable Field.DATATIME + *
  • google.protobuf.StringValue maps to a nullable Field.STRING + *
  • google.protobuf.DoubleValue maps to a nullable Field.DOUBLE + *
  • google.protobuf.FloatValue maps to a nullable Field.FLOAT + *
  • google.protobuf.BytesValue maps to a nullable Field.BYTES + *
  • google.protobuf.BoolValue maps to a nullable Field.BOOL + *
  • google.protobuf.Int64Value maps to a nullable Field.INT64 + *
  • google.protobuf.Int32Value maps to a nullable Field.INT32 + *
  • google.protobuf.UInt64Value maps to a nullable Field.INT64 + *
  • google.protobuf.UInt32Value maps to a nullable Field.INT32 + *
+ */ +@Experimental(Experimental.Kind.SCHEMAS) +public class ProtoSchema implements Serializable { + public static final long serialVersionUID = 1L; + private static final ProtoDomain STATIC_COMPILED_DOMAIN = new ProtoDomain(); + private static Map globalSchemaCache = new HashMap<>(); + private final Class rawType; + private final Map typeMapping; + private final ProtoDomain domain; + private transient Descriptors.Descriptor descriptor; + private transient SchemaCoder schemaCoder; + private transient Method fnNewBuilder; + private transient ArrayList getters; + + private ProtoSchema( + Class rawType, + Descriptors.Descriptor descriptor, + ProtoDomain domain, + Map overlayClasses) { + this.rawType = rawType; + this.descriptor = descriptor; + this.typeMapping = overlayClasses; + this.domain = domain; + init(); + } + + /** + * Create a new ProtoSchema Builder with the static compiled proto domain. This domain references + * only statically compiled Java Protobuf messages. + */ + public static Builder newBuilder() { + return new Builder(STATIC_COMPILED_DOMAIN); + } + + /** + * Create a new ProtoSchema Builder with a specific proto domain. It does not contain any messages + * of the static domain. A Domain is used for grouping different messages that belong together. + * Creating different schema builders with the same domain is safe. The resulting Protobuf + * messages created from the same domain with be equal. + */ + public static Builder newBuilder(ProtoDomain protoDomain) { + return new Builder(protoDomain); + } + + static Builder newBuilder(ProtoSchema protoSchema) { + return new Builder(protoSchema.domain).addTypeMapping(protoSchema.typeMapping); + } + + static ProtoSchema fromSchema(Schema schema) { + return globalSchemaCache.get(schema.getUUID()); + } + + static Schema.FieldType convertType(Descriptors.FieldDescriptor.Type type) { + switch (type) { + case DOUBLE: + return Schema.FieldType.DOUBLE; + case FLOAT: + return Schema.FieldType.FLOAT; + case INT64: + case UINT64: + case SINT64: + case FIXED64: + case SFIXED64: + return Schema.FieldType.INT64; + case INT32: + case FIXED32: + case UINT32: + case SFIXED32: + case SINT32: + return Schema.FieldType.INT32; + case BOOL: + return Schema.FieldType.BOOLEAN; + case STRING: + case ENUM: + return Schema.FieldType.STRING; + case BYTES: + return Schema.FieldType.BYTES; + case MESSAGE: + case GROUP: + break; + } + throw new RuntimeException("Field type not matched."); + } + + Map convertOptions(Descriptors.FieldDescriptor protoField) { + Map metadata = new HashMap<>(); + DescriptorProtos.FieldOptions options = protoField.getOptions(); + options + .getAllFields() + .forEach( + (fd, value) -> { + String name = fd.getFullName(); + if (name.startsWith("google.protobuf.FieldOptions")) { + name = fd.getName(); + } + if (value instanceof Message) { + Message message = (Message) value; + Descriptors.Descriptor descriptorForType = message.getDescriptorForType(); + List fields = descriptorForType.getFields(); + for (Descriptors.FieldDescriptor field : fields) { + metadata.put( + name + "." + field.getName(), + message.getField(field).toString().getBytes(StandardCharsets.UTF_8)); + } + } else { + metadata.put(name, value.toString().getBytes(StandardCharsets.UTF_8)); + } + }); + + options + .getUnknownFields() + .asMap() + .forEach( + (ix, ufs) -> { + Descriptors.FieldDescriptor fieldOptionById = domain.getFieldOptionById(ix); + if (fieldOptionById != null) { + String name = fieldOptionById.getFullName(); + decodeUnknownOptionValue(metadata, name, fieldOptionById, ufs); + } + }); + return metadata; + } + + private void decodeUnknownOptionValue( + Map metadata, + String name, + Descriptors.FieldDescriptor fieldDescriptor, + UnknownFieldSet.Field value) { + + switch (fieldDescriptor.getType()) { + case MESSAGE: + break; + case FIXED64: + metadata.put( + name, + value.getFixed64List().stream() + .map( + l -> { + if (l >= 0) { + return Long.toString(l); + } else { + return BigInteger.valueOf(l & 0x7FFFFFFFFFFFFFFFL).setBit(63).toString(); + } + }) + .collect(Collectors.joining("\n")) + .getBytes(StandardCharsets.UTF_8)); + break; + case FIXED32: + break; + case BOOL: + metadata.put( + name, + value.getVarintList().stream() + .map(l -> Boolean.valueOf(l.intValue() > 0).toString()) + .collect(Collectors.joining("\n")) + .getBytes(StandardCharsets.UTF_8)); + break; + case ENUM: + metadata.put( + name, + value.getVarintList().stream() + .map(l -> fieldDescriptor.getEnumType().findValueByNumber(l.intValue()).getName()) + .collect(Collectors.joining("\n")) + .getBytes(StandardCharsets.UTF_8)); + break; + case STRING: + metadata.put( + name, + value.getLengthDelimitedList().stream() + .map(l -> (l.toString(StandardCharsets.UTF_8))) + .collect(Collectors.joining("\n")) + .getBytes(StandardCharsets.UTF_8)); + break; + case INT32: + case INT64: + case SINT32: + case SINT64: + case UINT32: + metadata.put( + name, + value.getVarintList().stream() + .map(l -> l.toString()) + .collect(Collectors.joining("\n")) + .getBytes(StandardCharsets.UTF_8)); + break; + case UINT64: + metadata.put( + name, + value.getVarintList().stream() + .map( + l -> { + if (l >= 0) { + return Long.toString(l); + } else { + return BigInteger.valueOf(l & 0x7FFFFFFFFFFFFFFFL).setBit(63).toString(); + } + }) + .collect(Collectors.joining("\n")) + .getBytes(StandardCharsets.UTF_8)); + break; + case DOUBLE: + metadata.put( + name, + value.getFixed64List().stream() + .map(l -> String.valueOf(Double.longBitsToDouble(l))) + .collect(Collectors.joining("\n")) + .getBytes(StandardCharsets.UTF_8)); + break; + case FLOAT: + metadata.put( + name, + value.getFixed32List().stream() + .map(l -> String.valueOf(Float.intBitsToFloat(l))) + .collect(Collectors.joining("\n")) + .getBytes(StandardCharsets.UTF_8)); + break; + case BYTES: + if (value.getLengthDelimitedList().size() > 0) { + metadata.put(name, value.getLengthDelimitedList().get(0).toByteArray()); + } + break; + case SFIXED32: + break; + case SFIXED64: + break; + case GROUP: + break; + default: + throw new IllegalStateException( + "Conversion of Unknown Field for type " + + fieldDescriptor.getType().toString() + + " not implemented"); + } + } + + private static boolean isMap(Descriptors.FieldDescriptor protoField) { + return protoField.getType() == Descriptors.FieldDescriptor.Type.MESSAGE + && protoField.getMessageType().getFullName().endsWith("Entry") + && (protoField.getMessageType().findFieldByName("key") != null) + && (protoField.getMessageType().findFieldByName("value") != null); + } + + ProtoFieldOverlay createFieldLayer(Descriptors.FieldDescriptor protoField, boolean nullable) { + Descriptors.FieldDescriptor.Type fieldDescriptor = protoField.getType(); + ProtoFieldOverlay fieldOverlay; + switch (fieldDescriptor) { + case DOUBLE: + case FLOAT: + case INT64: + case UINT64: + case SINT64: + case FIXED64: + case SFIXED64: + case INT32: + case FIXED32: + case UINT32: + case SFIXED32: + case SINT32: + case BOOL: + case STRING: + fieldOverlay = new ProtoFieldOverlay.PrimitiveOverlay(this, protoField); + break; + case BYTES: + fieldOverlay = new ProtoFieldOverlay.BytesOverlay(this, protoField); + break; + case ENUM: + fieldOverlay = new ProtoFieldOverlay.EnumOverlay(this, protoField); + break; + case MESSAGE: + String fullName = protoField.getMessageType().getFullName(); + if (typeMapping.containsKey(fullName)) { + Class aClass = typeMapping.get(fullName); + try { + Constructor constructor = aClass.getConstructor(Descriptors.FieldDescriptor.class); + return (ProtoFieldOverlay) constructor.newInstance(protoField); + } catch (NoSuchMethodException e) { + throw new RuntimeException("Unable to find constructor for Overlay mapper."); + } catch (IllegalAccessException | InstantiationException | InvocationTargetException e) { + throw new RuntimeException("Unable to invoke Overlay mapper."); + } + } + switch (fullName) { + case "google.protobuf.Timestamp": + return new ProtoFieldOverlay.TimestampOverlay(this, protoField); + case "google.protobuf.StringValue": + case "google.protobuf.DoubleValue": + case "google.protobuf.FloatValue": + case "google.protobuf.BoolValue": + case "google.protobuf.Int64Value": + case "google.protobuf.Int32Value": + case "google.protobuf.UInt64Value": + case "google.protobuf.UInt32Value": + case "google.protobuf.BytesValue": + return new ProtoFieldOverlay.WrapperOverlay(this, protoField); + case "google.protobuf.Duration": + default: + if (isMap(protoField)) { + return new ProtoFieldOverlay.MapOverlay(this, protoField); + } else { + return new ProtoFieldOverlay.MessageOverlay(this, protoField); + } + } + case GROUP: + default: + throw new RuntimeException("Field type not matched."); + } + if (nullable) { + return new ProtoFieldOverlay.NullableOverlay(this, protoField, fieldOverlay); + } + return fieldOverlay; + } + + private ArrayList createFieldLayer(Descriptors.Descriptor descriptor) { + // Oneof fields are nullable, even as they are primitive or enums + List oneofMap = + descriptor.getOneofs().stream() + .flatMap(oneofDescriptor -> oneofDescriptor.getFields().stream()) + .collect(Collectors.toList()); + + ArrayList fieldOverlays = new ArrayList<>(); + Iterator protoFields = descriptor.getFields().iterator(); + for (int i = 0; i < descriptor.getFields().size(); i++) { + Descriptors.FieldDescriptor protoField = protoFields.next(); + if (protoField.isRepeated() && !isMap(protoField)) { + fieldOverlays.add(new ProtoFieldOverlay.ArrayOverlay(this, protoField)); + } else { + fieldOverlays.add(createFieldLayer(protoField, oneofMap.contains(protoField))); + } + } + return fieldOverlays; + } + + private void init() { + this.getters = createFieldLayer(descriptor); + + Schema.Builder builder = Schema.builder(); + for (ProtoFieldOverlay field : getters) { + builder.addField(field.getSchemaField()); + } + + Schema schema = builder.build(); + schema.setUUID(UUID.randomUUID()); + schemaCoder = + SchemaCoder.of( + schema, + TypeDescriptor.of(rawType), + new MessageToRowFunction(), + new RowToMessageFunction()); + + globalSchemaCache.put(schema.getUUID(), this); + try { + if (DynamicMessage.class.equals(rawType)) { + this.fnNewBuilder = rawType.getMethod("newBuilder", Descriptors.Descriptor.class); + } else { + this.fnNewBuilder = rawType.getMethod("newBuilder"); + } + } catch (NoSuchMethodException e) { + } + } + + public Schema getSchema() { + return this.schemaCoder.getSchema(); + } + + public SchemaCoder getSchemaCoder() { + return schemaCoder; + } + + public SerializableFunction getToRowFunction() { + return schemaCoder.getToRowFunction(); + } + + public SerializableFunction getFromRowFunction() { + return schemaCoder.getFromRowFunction(); + } + + private void writeObject(ObjectOutputStream oos) throws IOException { + oos.defaultWriteObject(); + if (DynamicMessage.class.equals(this.rawType)) { + if (this.descriptor == null) { + throw new RuntimeException("DynamicMessages require provider a Descriptor to the coder."); + } + oos.writeUTF(descriptor.getFullName()); + } + } + + private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException { + ois.defaultReadObject(); + if (DynamicMessage.class.equals(rawType)) { + descriptor = domain.getDescriptor(ois.readUTF()); + } else { + descriptor = ProtobufUtil.getDescriptorForClass(rawType); + } + init(); + } + + public ProtoDomain getDomain() { + return domain; + } + + public static class Builder implements Serializable { + + private ProtoDomain domain; + private Map mappings = new HashMap<>(); + + public Builder(ProtoDomain domain) { + this.domain = domain; + } + + public Builder addTypeMapping(Map mappings) { + this.mappings.putAll(mappings); + return this; + } + + public Builder addTypeMapping(String message, Class mappingClass) { + this.mappings.put(message, mappingClass); + return this; + } + + public ProtoSchema forType(Class rawType) { + return new ProtoSchema( + rawType, + ProtobufUtil.getDescriptorForClass(rawType), + domain, + ImmutableMap.copyOf(mappings)); + } + + public ProtoSchema forDescriptor(Descriptors.Descriptor descriptor) { + return new ProtoSchema( + DynamicMessage.class, descriptor, domain, ImmutableMap.copyOf(mappings)); + } + } + + /** Overlay. */ + public static class ProtoOverlayFactory implements Factory> { + + public ProtoOverlayFactory() {} + + @Override + public List create(Class clazz, Schema schema) { + return ProtoSchema.fromSchema(schema).getters; + } + } + + private class MessageToRowFunction implements SerializableFunction { + + private MessageToRowFunction() {} + + @Override + public Row apply(Message input) { + return Row.withSchema(schemaCoder.getSchema()) + .withFieldValueGettersHandleCollections(true) + .withFieldValueGetters(new ProtoOverlayFactory(), input) + .build(); + } + } + + private class RowToMessageFunction implements SerializableFunction { + + private RowToMessageFunction() {} + + @Override + public T apply(Row input) { + Message.Builder builder; + try { + if (DynamicMessage.class.equals(rawType)) { + builder = (Message.Builder) fnNewBuilder.invoke(rawType, descriptor); + } else { + builder = (Message.Builder) fnNewBuilder.invoke(rawType); + } + } catch (IllegalAccessException | InvocationTargetException e) { + throw new RuntimeException("Can't invoke newBuilder on the Protobuf message class.", e); + } + + Iterator values = input.getValues().iterator(); + Iterator overlayIterator = getters.iterator(); + + for (int i = 0; i < input.getValues().size(); i++) { + ProtoFieldOverlay getter = overlayIterator.next(); + Object value = values.next(); + getter.set(builder, value); + } + return (T) builder.build(); + } + } +} diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaProvider.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaProvider.java new file mode 100644 index 000000000000..45503a1c6628 --- /dev/null +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaProvider.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.protobuf; + +import com.google.protobuf.DynamicMessage; +import javax.annotation.Nullable; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaProvider; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Schema provider for Protobuf messages. The provider is able to handle pre compiled Message file + * without external help. For Dynamic Messages a Descriptor needs to be registered up front on a + * specific URN. + * + *

It's possible to inherit this class for a specific implementation that communicates with an + * external registry that maps those URN's with Descriptors. + */ +@Experimental(Experimental.Kind.SCHEMAS) +public class ProtoSchemaProvider implements SchemaProvider { + private static final Logger LOG = LoggerFactory.getLogger(ProtoSchemaProvider.class); + + private final ProtoSchema.Builder protoSchemaBuilder; + + public ProtoSchemaProvider() { + this.protoSchemaBuilder = ProtoSchema.newBuilder(); + } + + public ProtoSchemaProvider(ProtoSchema.Builder protoSchemaBuilder) { + this.protoSchemaBuilder = protoSchemaBuilder; + } + + @Override + public Schema schemaFor(TypeDescriptor typeDescriptor) { + checkForDynamicType(typeDescriptor); + return protoSchemaBuilder.forType(typeDescriptor.getRawType()).getSchema(); + } + + @Nullable + @Override + public SerializableFunction toRowFunction(TypeDescriptor typeDescriptor) { + checkForDynamicType(typeDescriptor); + return protoSchemaBuilder + .forType(typeDescriptor.getRawType()) + .getSchemaCoder() + .getToRowFunction(); + } + + @Override + public SerializableFunction fromRowFunction(TypeDescriptor typeDescriptor) { + checkForDynamicType(typeDescriptor); + return protoSchemaBuilder + .forType(typeDescriptor.getRawType()) + .getSchemaCoder() + .getFromRowFunction(); + } + + private void checkForDynamicType(TypeDescriptor typeDescriptor) { + if (typeDescriptor.getRawType().equals(DynamicMessage.class)) { + throw new RuntimeException( + "DynamicMessage is not allowed for the standard ProtoSchemaProvider, use ProtoSchema build instead."); + } + } +} diff --git a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTest.java b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTest.java new file mode 100644 index 000000000000..b49ac5d5d5d6 --- /dev/null +++ b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTest.java @@ -0,0 +1,577 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.protobuf; + +import static org.junit.Assert.assertEquals; + +import com.google.protobuf.Descriptors; +import com.google.protobuf.DynamicMessage; +import com.google.protobuf.Message; +import java.io.IOException; +import java.util.Objects; +import org.apache.beam.sdk.coders.RowCoder; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaCoder; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.util.CoderUtils; +import org.apache.beam.sdk.util.SerializableUtils; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Collection of standard tests for Protobuf Schema support. */ +@RunWith(JUnit4.class) +public class ProtoSchemaTest { + + private static final Schema PRIMITIVE_SCHEMA = + Schema.builder() + .addDoubleField("primitive_double") + .addFloatField("primitive_float") + .addInt32Field("primitive_int32") + .addInt64Field("primitive_int64") + .addInt32Field("primitive_uint32") + .addInt64Field("primitive_uint64") + .addInt32Field("primitive_sint32") + .addInt64Field("primitive_sint64") + .addInt32Field("primitive_fixed32") + .addInt64Field("primitive_fixed64") + .addInt32Field("primitive_sfixed32") + .addInt64Field("primitive_sfixed64") + .addBooleanField("primitive_bool") + .addStringField("primitive_string") + .addByteArrayField("primitive_bytes") + .build(); + static final Row PRIMITIVE_DEFAULT_ROW = + Row.withSchema(PRIMITIVE_SCHEMA) + .addValue((double) 0) + .addValue((float) 0) + .addValue(0) + .addValue(0L) + .addValue(0) + .addValue(0L) + .addValue(0) + .addValue(0L) + .addValue(0) + .addValue(0L) + .addValue(0) + .addValue(0L) + .addValue(Boolean.FALSE) + .addValue("") + .addValue(new byte[] {}) + .build(); + static final Schema MESSAGE_SCHEMA = + Schema.builder() + .addField("message", Schema.FieldType.row(PRIMITIVE_SCHEMA).withNullable(true)) + .addField( + "repeated_message", + Schema.FieldType.array( + // TODO: are the nullable's correct + Schema.FieldType.row(PRIMITIVE_SCHEMA).withNullable(true)) + .withNullable(true)) + .build(); + private static final Row MESSAGE_DEFAULT_ROW = + Row.withSchema(MESSAGE_SCHEMA).addValue(null).addValue(null).build(); + private static final Schema REPEAT_PRIMITIVE_SCHEMA = + Schema.builder() + .addField( + "repeated_double", Schema.FieldType.array(Schema.FieldType.DOUBLE).withNullable(true)) + .addField( + "repeated_float", Schema.FieldType.array(Schema.FieldType.FLOAT).withNullable(true)) + .addField( + "repeated_int32", Schema.FieldType.array(Schema.FieldType.INT32).withNullable(true)) + .addField( + "repeated_int64", Schema.FieldType.array(Schema.FieldType.INT64).withNullable(true)) + .addField( + "repeated_uint32", Schema.FieldType.array(Schema.FieldType.INT32).withNullable(true)) + .addField( + "repeated_uint64", Schema.FieldType.array(Schema.FieldType.INT64).withNullable(true)) + .addField( + "repeated_sint32", Schema.FieldType.array(Schema.FieldType.INT32).withNullable(true)) + .addField( + "repeated_sint64", Schema.FieldType.array(Schema.FieldType.INT64).withNullable(true)) + .addField( + "repeated_fixed32", Schema.FieldType.array(Schema.FieldType.INT32).withNullable(true)) + .addField( + "repeated_fixed64", Schema.FieldType.array(Schema.FieldType.INT64).withNullable(true)) + .addField( + "repeated_sfixed32", + Schema.FieldType.array(Schema.FieldType.INT32).withNullable(true)) + .addField( + "repeated_sfixed64", + Schema.FieldType.array(Schema.FieldType.INT64).withNullable(true)) + .addField( + "repeated_bool", Schema.FieldType.array(Schema.FieldType.BOOLEAN).withNullable(true)) + .addField( + "repeated_string", Schema.FieldType.array(Schema.FieldType.STRING).withNullable(true)) + .addField( + "repeated_bytes", Schema.FieldType.array(Schema.FieldType.BYTES).withNullable(true)) + .build(); + static final Row REPEAT_PRIMITIVE_DEFAULT_ROW = + Row.withSchema(REPEAT_PRIMITIVE_SCHEMA) + .addValue(null) + .addValue(null) + .addValue(null) + .addValue(null) + .addValue(null) + .addValue(null) + .addValue(null) + .addValue(null) + .addValue(null) + .addValue(null) + .addValue(null) + .addValue(null) + .addValue(null) + .addValue(null) + .addValue(null) + .build(); + private static final Schema COMPLEX_SCHEMA = + Schema.builder() + .addField("special_enum", Schema.FieldType.STRING) + .addField( + "repeated_enum", Schema.FieldType.array(Schema.FieldType.STRING).withNullable(true)) + .addField("oneof_int32", Schema.FieldType.INT32.withNullable(true)) + .addField("oneof_bool", Schema.FieldType.BOOLEAN.withNullable(true)) + .addField("oneof_string", Schema.FieldType.STRING.withNullable(true)) + .addField("oneof_primitive", Schema.FieldType.row(PRIMITIVE_SCHEMA).withNullable(true)) + .addField( + "x", + Schema.FieldType.map(Schema.FieldType.STRING, Schema.FieldType.INT32) + .withNullable(true)) + .addField( + "y", + Schema.FieldType.map(Schema.FieldType.STRING, Schema.FieldType.STRING) + .withNullable(true)) + .addField( + "z", + // TODO: null in map, does it make sense. + Schema.FieldType.map( + Schema.FieldType.STRING, + Schema.FieldType.row(PRIMITIVE_SCHEMA).withNullable(true)) + .withNullable(true)) + .addField("oneof_int64", Schema.FieldType.INT64.withNullable(true)) + .addField("oneof_double", Schema.FieldType.DOUBLE.withNullable(true)) + .build(); + static final Row COMPLEX_DEFAULT_ROW = + Row.withSchema(COMPLEX_SCHEMA) + .addValue("UNKNOWN") + .addValue(null) + .addValue(null) + .addValue(null) + .addValue(null) + .addValue(null) + .addValue(null) + .addValue(null) + .addValue(null) + .addValue(null) + .addValue(null) + .build(); + private static final Schema WKT_MESSAGE_SCHEMA = + Schema.builder() + .addField("nullable_double", Schema.FieldType.DOUBLE.withNullable(true)) + .addField("nullable_float", Schema.FieldType.FLOAT.withNullable(true)) + .addField("nullable_int32", Schema.FieldType.INT32.withNullable(true)) + .addField("nullable_int64", Schema.FieldType.INT64.withNullable(true)) + .addField("nullable_uint32", Schema.FieldType.INT32.withNullable(true)) + .addField("nullable_uint64", Schema.FieldType.INT64.withNullable(true)) + // xxx + .addField("nullable_bool", Schema.FieldType.BOOLEAN.withNullable(true)) + .addField("nullable_string", Schema.FieldType.STRING.withNullable(true)) + .addField("nullable_bytes", Schema.FieldType.BYTES.withNullable(true)) + // + .addField("wkt_timestamp", Schema.FieldType.DATETIME.withNullable(true)) + .build(); + static final Row WKT_MESSAGE_DEFAULT_ROW = + Row.withSchema(WKT_MESSAGE_SCHEMA) + .addValue(null) + .addValue(null) + .addValue(null) + .addValue(null) + .addValue(null) + .addValue(null) + .addValue(null) + .addValue(null) + .addValue(null) + .addValue(null) + .build(); + + @Test + public void testPrimitiveSchema() { + Schema schema = + new ProtoSchemaProvider() + .schemaFor(TypeDescriptor.of(Proto3SchemaMessages.Primitive.class)); + assertEquals(PRIMITIVE_SCHEMA, schema); + } + + @Test + public void testPrimitiveDefaultRow() { + SerializableFunction toRowFunction = + new ProtoSchemaProvider() + .toRowFunction(TypeDescriptor.of(Proto3SchemaMessages.Primitive.class)); + Row row = toRowFunction.apply(Proto3SchemaMessages.Primitive.newBuilder().build()); + assertEquals(PRIMITIVE_DEFAULT_ROW, row); + } + + @Test + public void testMessageSchema() { + Schema schema = + new ProtoSchemaProvider().schemaFor(TypeDescriptor.of(Proto3SchemaMessages.Message.class)); + assertEquals(MESSAGE_SCHEMA, schema); + } + + @Test + public void testMessageDefaultRow() { + SerializableFunction toRowFunction = + new ProtoSchemaProvider() + .toRowFunction(TypeDescriptor.of(Proto3SchemaMessages.Message.class)); + Row row = toRowFunction.apply(Proto3SchemaMessages.Message.newBuilder().build()); + assertEquals(MESSAGE_DEFAULT_ROW, row); + } + + @Test + public void testRepeatPrimitiveSchema() { + Schema schema = + new ProtoSchemaProvider() + .schemaFor(TypeDescriptor.of(Proto3SchemaMessages.RepeatPrimitive.class)); + assertEquals(REPEAT_PRIMITIVE_SCHEMA, schema); + } + + @Test + public void testRepeatPrimitiveDefaultRow() { + SerializableFunction toRowFunction = + new ProtoSchemaProvider() + .toRowFunction(TypeDescriptor.of(Proto3SchemaMessages.RepeatPrimitive.class)); + Row row = toRowFunction.apply(Proto3SchemaMessages.RepeatPrimitive.newBuilder().build()); + assertEquals(REPEAT_PRIMITIVE_DEFAULT_ROW, row); + } + + @Test + public void testComplexSchema() { + Schema schema = + new ProtoSchemaProvider().schemaFor(TypeDescriptor.of(Proto3SchemaMessages.Complex.class)); + assertEquals(COMPLEX_SCHEMA, schema); + } + + @Test + public void testComplexDefaultRow() { + SerializableFunction toRowFunction = + new ProtoSchemaProvider() + .toRowFunction(TypeDescriptor.of(Proto3SchemaMessages.Complex.class)); + Row row = toRowFunction.apply(Proto3SchemaMessages.Complex.newBuilder().build()); + assertEquals(COMPLEX_DEFAULT_ROW, row); + } + + @Test + public void testWktMessageSchema() { + Schema schema = + new ProtoSchemaProvider() + .schemaFor(TypeDescriptor.of(Proto3SchemaMessages.WktMessage.class)); + assertEquals(WKT_MESSAGE_SCHEMA, schema); + } + + @Test + public void testWktMessageDefaultRow() { + SerializableFunction toRowFunction = + new ProtoSchemaProvider() + .toRowFunction(TypeDescriptor.of(Proto3SchemaMessages.WktMessage.class)); + Row row = toRowFunction.apply(Proto3SchemaMessages.WktMessage.newBuilder().build()); + assertEquals(WKT_MESSAGE_DEFAULT_ROW, row); + } + + @Test + public void testCoder() throws Exception { + SchemaCoder schemaCoder = + ProtoSchema.newBuilder().forType(Proto3SchemaMessages.Complex.class).getSchemaCoder(); + RowCoder rowCoder = RowCoder.of(schemaCoder.getSchema()); + + byte[] schemaCoderBytes = SerializableUtils.serializeToByteArray(schemaCoder); + SchemaCoder schemaCoderCoded = + (SchemaCoder) SerializableUtils.deserializeFromByteArray(schemaCoderBytes, ""); + byte[] rowCoderBytes = SerializableUtils.serializeToByteArray(rowCoder); + RowCoder rowCoderCoded = + (RowCoder) SerializableUtils.deserializeFromByteArray(rowCoderBytes, ""); + + Proto3SchemaMessages.Complex message = + Proto3SchemaMessages.Complex.newBuilder() + .setOneofString("foobar") + .setSpecialEnum(Proto3SchemaMessages.Complex.EnumNested.FOO) + .build(); + + Row row = schemaCoder.getToRowFunction().apply(message); + byte[] rowBytes = CoderUtils.encodeToByteArray(rowCoder, row); + + Row rowCoded = CoderUtils.decodeFromByteArray(rowCoderCoded, rowBytes); + assertEquals(row, rowCoded); + + Message messageVerify = schemaCoder.getFromRowFunction().apply(rowCoded); + assertEquals(message, messageVerify); + + Message messageCoded = schemaCoderCoded.getFromRowFunction().apply(rowCoded); + assertEquals(message, messageCoded); + } + + @Test + public void testCoderOnDynamic() throws Exception { + Descriptors.Descriptor descriptor = Proto3SchemaMessages.Complex.getDescriptor(); + Descriptors.FieldDescriptor oneofString = descriptor.findFieldByName("oneof_string"); + Descriptors.FieldDescriptor specialEnum = descriptor.findFieldByName("special_enum"); + + SchemaCoder schemaCoder = + ProtoSchema.newBuilder(ProtoDomain.buildFrom(descriptor)) + .forDescriptor(descriptor) + .getSchemaCoder(); + RowCoder rowCoder = RowCoder.of(schemaCoder.getSchema()); + + byte[] schemaCoderBytes = SerializableUtils.serializeToByteArray(schemaCoder); + SchemaCoder schemaCoderCoded = + (SchemaCoder) SerializableUtils.deserializeFromByteArray(schemaCoderBytes, ""); + byte[] rowCoderBytes = SerializableUtils.serializeToByteArray(rowCoder); + RowCoder rowCoderCoded = + (RowCoder) SerializableUtils.deserializeFromByteArray(rowCoderBytes, ""); + + DynamicMessage message = + DynamicMessage.newBuilder(descriptor) + .setField(oneofString, "foobar") + .setField(specialEnum, Proto3SchemaMessages.Complex.EnumNested.FOO.getValueDescriptor()) + .build(); + + Row row = schemaCoder.getToRowFunction().apply(message); + byte[] rowBytes = CoderUtils.encodeToByteArray(rowCoder, row); + + Row rowCoded = CoderUtils.decodeFromByteArray(rowCoderCoded, rowBytes); + assertEquals(row, rowCoded); + + Message messageVerify = schemaCoder.getFromRowFunction().apply(rowCoded); + assertEquals(message, messageVerify); + + Message messageCoded = schemaCoderCoded.getFromRowFunction().apply(rowCoded); + Descriptors.FieldDescriptor oneofStringCoded = + messageCoded.getDescriptorForType().findFieldByName("oneof_string"); + Descriptors.FieldDescriptor specialEnumCoded = + messageCoded.getDescriptorForType().findFieldByName("special_enum"); + assertEquals(message.getField(oneofString), messageCoded.getField(oneofStringCoded)); + assertEquals( + ((Descriptors.EnumValueDescriptor) message.getField(specialEnum)).getFullName(), + ((Descriptors.EnumValueDescriptor) messageCoded.getField(specialEnumCoded)).getFullName()); + } + + @Test + public void testLogicalTypeRegistration() { + ProtoSchemaProvider protoSchemaProvider = + new ProtoSchemaProvider( + ProtoSchema.newBuilder() + .addTypeMapping("proto3_schema_messages.LatLng", LatLngOverlay.class)); + + SerializableFunction toRowFunction = + protoSchemaProvider.toRowFunction( + TypeDescriptor.of(Proto3SchemaMessages.LogicalTypes.class)); + Row row = + toRowFunction.apply( + Proto3SchemaMessages.LogicalTypes.newBuilder() + .setGps( + Proto3SchemaMessages.LatLng.newBuilder() + .setLatitude(1.2) + .setLongitude(3.5) + .build()) + .build()); + assertEquals(new LatLng(1.2, 3.5), row.getValue("gps")); + } + + /** Example Java type for testing LogicalTypes. */ + public static class LatLng { + double lat; + double lng; + + public LatLng(Double lat, Double lng) { + this.lat = lat; + this.lng = lng; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + LatLng latLng = (LatLng) o; + return Double.compare(latLng.lat, lat) == 0 && Double.compare(latLng.lng, lng) == 0; + } + + @Override + public int hashCode() { + return Objects.hash(lat, lng); + } + } + + /** + * Example of a LogicalType converter. It will make sure that the type is convertible to a + * FieldType. + */ + public static class LatLngLogicalType implements Schema.LogicalType { + + static final Schema BASE_TYPE = + Schema.builder() + .addField("lat", Schema.FieldType.DOUBLE) + .addField("lng", Schema.FieldType.DOUBLE) + .build(); + + @Override + public String getIdentifier() { + return "LatLngLogicalType"; + } + + @Override + public Schema.FieldType getBaseType() { + return Schema.FieldType.row(LatLngLogicalType.BASE_TYPE); + } + + @Override + public Row toBaseType(LatLng input) { + return Row.withSchema(BASE_TYPE).addValue(input.lat).addValue(input.lng).build(); + } + + @Override + public LatLng toInputType(Row base) { + return new LatLng(base.getDouble("lat"), base.getDouble("lng")); + } + } + + /** Custom Protobuf field overlay that returns a custom LogicalType. */ + public static class LatLngOverlay implements ProtoFieldOverlay { + private Descriptors.FieldDescriptor fieldDescriptor; + private Descriptors.FieldDescriptor latitudeFieldDescriptor; + private Descriptors.FieldDescriptor longitudeFieldDescriptor; + + public LatLngOverlay(Descriptors.FieldDescriptor fieldDescriptor) { + this.fieldDescriptor = fieldDescriptor; + latitudeFieldDescriptor = fieldDescriptor.getMessageType().findFieldByName("latitude"); + longitudeFieldDescriptor = fieldDescriptor.getMessageType().findFieldByName("longitude"); + } + + @Override + public LatLng get(Message message) { + Message latLngMessage = (Message) message.getField(fieldDescriptor); + return new LatLng( + (double) latLngMessage.getField(latitudeFieldDescriptor), + (double) latLngMessage.getField(longitudeFieldDescriptor)); + } + + @Override + public String name() { + return fieldDescriptor.getName(); + } + + @Override + public LatLng convertGetObject(Descriptors.FieldDescriptor fieldDescriptor, Object object) { + return null; + } + + @Override + public void set(Message.Builder message, LatLng value) { + message.setField( + fieldDescriptor, + Proto3SchemaMessages.LatLng.newBuilder() + .setLongitude(value.lng) + .setLatitude(value.lat) + .build()); + } + + @Override + public Object convertSetObject(Descriptors.FieldDescriptor fieldDescriptor, Object value) { + return null; + } + + @Override + public Schema.Field getSchemaField() { + return Schema.Field.of( + fieldDescriptor.getName(), Schema.FieldType.logicalType(new LatLngLogicalType())); + } + } + + @Test + public void testMessageWithMetaSchema() { + Schema schema = + new ProtoSchemaProvider() + .schemaFor(TypeDescriptor.of(Proto3SchemaMessages.MessageWithMeta.class)); + Schema.Field fieldWithDescription = schema.getField("field_with_description"); + assertEquals( + "Cool field", + fieldWithDescription + .getType() + .getMetadataString("proto3_schema_messages.field_meta.description")); + assertEquals( + "0", + fieldWithDescription + .getType() + .getMetadataString("proto3_schema_messages.field_meta.foobar")); + assertEquals("", fieldWithDescription.getType().getMetadataString("deprecated")); + + Schema.Field fieldWithFoobar = schema.getField("field_with_foobar"); + assertEquals( + "", + fieldWithFoobar + .getType() + .getMetadataString("proto3_schema_messages.field_meta.description")); + assertEquals( + "42", + fieldWithFoobar.getType().getMetadataString("proto3_schema_messages.field_meta.foobar")); + assertEquals("", fieldWithFoobar.getType().getMetadataString("deprecated")); + + Schema.Field fieldWithDeprecation = schema.getField("field_with_deprecation"); + assertEquals( + "", + fieldWithDeprecation + .getType() + .getMetadataString("proto3_schema_messages.field_meta.description")); + assertEquals( + "", + fieldWithDeprecation + .getType() + .getMetadataString("proto3_schema_messages.field_meta.foobar")); + assertEquals("true", fieldWithDeprecation.getType().getMetadataString("deprecated")); + } + + @Test + public void testMessageWithMetaDynamicSchema() throws IOException { + ProtoDomain domain = ProtoDomain.buildFrom(getClass().getResourceAsStream("test_option_v1.pb")); + Descriptors.Descriptor descriptor = domain.getDescriptor("test.option.v1.MessageWithOptions"); + Schema schema = ProtoSchema.newBuilder(domain).forDescriptor(descriptor).getSchema(); + Schema.Field field; + field = schema.getField("field_with_fieldoption_double"); + assertEquals("100.1", field.getType().getMetadataString("test.option.v1.fieldoption_double")); + field = schema.getField("field_with_fieldoption_float"); + assertEquals("101.2", field.getType().getMetadataString("test.option.v1.fieldoption_float")); + field = schema.getField("field_with_fieldoption_int32"); + assertEquals("102", field.getType().getMetadataString("test.option.v1.fieldoption_int32")); + field = schema.getField("field_with_fieldoption_int64"); + assertEquals("103", field.getType().getMetadataString("test.option.v1.fieldoption_int64")); + field = schema.getField("field_with_fieldoption_bool"); + assertEquals("true", field.getType().getMetadataString("test.option.v1.fieldoption_bool")); + field = schema.getField("field_with_fieldoption_string"); + assertEquals("Oh yeah", field.getType().getMetadataString("test.option.v1.fieldoption_string")); + field = schema.getField("field_with_fieldoption_enum"); + assertEquals("ENUM1", field.getType().getMetadataString("test.option.v1.fieldoption_enum")); + field = schema.getField("field_with_fieldoption_repeated_string"); + assertEquals( + "Oh yeah\nOh no", + field.getType().getMetadataString("test.option.v1.fieldoption_repeated_string")); + } +} diff --git a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaValuesTest.java b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaValuesTest.java new file mode 100644 index 000000000000..7d9e8aa6ac7a --- /dev/null +++ b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaValuesTest.java @@ -0,0 +1,670 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.protobuf; + +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTest.COMPLEX_DEFAULT_ROW; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTest.MESSAGE_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTest.PRIMITIVE_DEFAULT_ROW; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTest.REPEAT_PRIMITIVE_DEFAULT_ROW; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTest.WKT_MESSAGE_DEFAULT_ROW; +import static org.junit.Assert.assertEquals; + +import com.google.protobuf.BoolValue; +import com.google.protobuf.ByteString; +import com.google.protobuf.BytesValue; +import com.google.protobuf.DoubleValue; +import com.google.protobuf.DynamicMessage; +import com.google.protobuf.FloatValue; +import com.google.protobuf.Int32Value; +import com.google.protobuf.Int64Value; +import com.google.protobuf.Message; +import com.google.protobuf.StringValue; +import com.google.protobuf.Timestamp; +import com.google.protobuf.UInt32Value; +import com.google.protobuf.UInt64Value; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +/** Collection of tests for values on Protobuf Messages and Rows. */ +@RunWith(Parameterized.class) +public class ProtoSchemaValuesTest { + + private final Message proto; + private final Row rowObject; + private SerializableFunction toRowFunction; + private SerializableFunction fromRowFunction; + + public ProtoSchemaValuesTest(String description, Message proto, Row rowObject) { + this.proto = proto; + this.rowObject = rowObject; + } + + @Parameterized.Parameters(name = "{0}") + public static Collection data() { + List data = new ArrayList<>(); + data.add( + new Object[] { + "primitive_int32", + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveInt32(Integer.MAX_VALUE).build(), + change(PRIMITIVE_DEFAULT_ROW, "primitive_int32", Integer.MAX_VALUE) + }); + data.add( + new Object[] { + "primitive_int64", + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveInt64(Long.MAX_VALUE).build(), + change(PRIMITIVE_DEFAULT_ROW, "primitive_int64", Long.MAX_VALUE) + }); + data.add( + new Object[] { + "primitive_uint32", + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveUint32(Integer.MAX_VALUE).build(), + change(PRIMITIVE_DEFAULT_ROW, "primitive_uint32", Integer.MAX_VALUE) + }); + data.add( + new Object[] { + "primitive_uint64", + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveUint64(Long.MAX_VALUE).build(), + change(PRIMITIVE_DEFAULT_ROW, "primitive_uint64", Long.MAX_VALUE) + }); + data.add( + new Object[] { + "primitive_sint32", + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveSint32(Integer.MAX_VALUE).build(), + change(PRIMITIVE_DEFAULT_ROW, "primitive_sint32", Integer.MAX_VALUE) + }); + data.add( + new Object[] { + "primitive_sint64", + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveSint64(Long.MAX_VALUE).build(), + change(PRIMITIVE_DEFAULT_ROW, "primitive_sint64", Long.MAX_VALUE) + }); + data.add( + new Object[] { + "primitive_fixed32", + Proto3SchemaMessages.Primitive.newBuilder() + .setPrimitiveFixed32(Integer.MAX_VALUE) + .build(), + change(PRIMITIVE_DEFAULT_ROW, "primitive_fixed32", Integer.MAX_VALUE) + }); + data.add( + new Object[] { + "primitive_fixed64", + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveFixed64(Long.MAX_VALUE).build(), + change(PRIMITIVE_DEFAULT_ROW, "primitive_fixed64", Long.MAX_VALUE) + }); + data.add( + new Object[] { + "primitive_sfixed32", + Proto3SchemaMessages.Primitive.newBuilder() + .setPrimitiveSfixed32(Integer.MAX_VALUE) + .build(), + change(PRIMITIVE_DEFAULT_ROW, "primitive_sfixed32", Integer.MAX_VALUE) + }); + data.add( + new Object[] { + "primitive_sfixed64", + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveSfixed64(Long.MAX_VALUE).build(), + change(PRIMITIVE_DEFAULT_ROW, "primitive_sfixed64", Long.MAX_VALUE) + }); + data.add( + new Object[] { + "primitive_bool", + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveBool(true).build(), + change(PRIMITIVE_DEFAULT_ROW, "primitive_bool", true) + }); + data.add( + new Object[] { + "primitive_string", + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveString("lovely string").build(), + change(PRIMITIVE_DEFAULT_ROW, "primitive_string", "lovely string") + }); + data.add( + new Object[] { + "primitive_bytes", + Proto3SchemaMessages.Primitive.newBuilder() + .setPrimitiveBytes(ByteString.copyFrom(new byte[] {(byte) 0x0F})) + .build(), + change(PRIMITIVE_DEFAULT_ROW, "primitive_bytes", new byte[] {(byte) 0x0F}) + }); + + data.add( + new Object[] { + "repeated_double", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedDouble(Double.MAX_VALUE) + .addRepeatedDouble(0.0) + .addRepeatedDouble(Double.MIN_VALUE) + .build(), + change( + REPEAT_PRIMITIVE_DEFAULT_ROW, + "repeated_double", + Arrays.asList(Double.MAX_VALUE, 0.0, Double.MIN_VALUE)) + }); + data.add( + new Object[] { + "repeated_float", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedFloat(Float.MAX_VALUE) + .addRepeatedFloat((float) 0.0) + .addRepeatedFloat(Float.MIN_VALUE) + .build(), + change( + REPEAT_PRIMITIVE_DEFAULT_ROW, + "repeated_float", + Arrays.asList(Float.MAX_VALUE, (float) 0.0, Float.MIN_VALUE)) + }); + data.add( + new Object[] { + "repeated_int32", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedInt32(Integer.MAX_VALUE) + .addRepeatedInt32(0) + .addRepeatedInt32(Integer.MIN_VALUE) + .build(), + change( + REPEAT_PRIMITIVE_DEFAULT_ROW, + "repeated_int32", + Arrays.asList(Integer.MAX_VALUE, 0, Integer.MIN_VALUE)) + }); + data.add( + new Object[] { + "repeated_int64", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedInt64(Long.MAX_VALUE) + .addRepeatedInt64(0L) + .addRepeatedInt64(Long.MIN_VALUE) + .build(), + change( + REPEAT_PRIMITIVE_DEFAULT_ROW, + "repeated_int64", + Arrays.asList(Long.MAX_VALUE, 0L, Long.MIN_VALUE)) + }); + + data.add( + new Object[] { + "repeated_uint32", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedUint32(Integer.MAX_VALUE) + .addRepeatedUint32(0) + .build(), + change( + REPEAT_PRIMITIVE_DEFAULT_ROW, "repeated_uint32", Arrays.asList(Integer.MAX_VALUE, 0)) + }); + data.add( + new Object[] { + "repeated_uint64", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedUint64(Long.MAX_VALUE) + .addRepeatedUint64(0L) + .build(), + change(REPEAT_PRIMITIVE_DEFAULT_ROW, "repeated_uint64", Arrays.asList(Long.MAX_VALUE, 0L)) + }); + + data.add( + new Object[] { + "repeated_sint32", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedSint32(Integer.MAX_VALUE) + .addRepeatedSint32(0) + .build(), + change( + REPEAT_PRIMITIVE_DEFAULT_ROW, "repeated_sint32", Arrays.asList(Integer.MAX_VALUE, 0)) + }); + data.add( + new Object[] { + "repeated_sint64", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedSint64(Long.MAX_VALUE) + .addRepeatedSint64(0L) + .build(), + change(REPEAT_PRIMITIVE_DEFAULT_ROW, "repeated_sint64", Arrays.asList(Long.MAX_VALUE, 0L)) + }); + data.add( + new Object[] { + "repeated_fixed32", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedFixed32(Integer.MAX_VALUE) + .addRepeatedFixed32(0) + .build(), + change( + REPEAT_PRIMITIVE_DEFAULT_ROW, "repeated_fixed32", Arrays.asList(Integer.MAX_VALUE, 0)) + }); + data.add( + new Object[] { + "repeated_fixed64", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedFixed64(Long.MAX_VALUE) + .addRepeatedFixed64(0) + .build(), + change( + REPEAT_PRIMITIVE_DEFAULT_ROW, "repeated_fixed64", Arrays.asList(Long.MAX_VALUE, 0L)) + }); + data.add( + new Object[] { + "repeated_sfixed32", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedSfixed32(Integer.MAX_VALUE) + .addRepeatedSfixed32(0) + .build(), + change( + REPEAT_PRIMITIVE_DEFAULT_ROW, + "repeated_sfixed32", + Arrays.asList(Integer.MAX_VALUE, 0)) + }); + data.add( + new Object[] { + "repeated_sfixed64", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedSfixed64(Long.MAX_VALUE) + .addRepeatedSfixed64(0L) + .build(), + change( + REPEAT_PRIMITIVE_DEFAULT_ROW, "repeated_sfixed64", Arrays.asList(Long.MAX_VALUE, 0L)) + }); + data.add( + new Object[] { + "repeated_bool", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedBool(true) + .addRepeatedBool(false) + .build(), + change(REPEAT_PRIMITIVE_DEFAULT_ROW, "repeated_bool", Arrays.asList(true, false)) + }); + data.add( + new Object[] { + "repeated_string", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedString("foo") + .addRepeatedString("bar") + .build(), + change(REPEAT_PRIMITIVE_DEFAULT_ROW, "repeated_string", Arrays.asList("foo", "bar")) + }); + data.add( + new Object[] { + "repeated_bytes", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedBytes(ByteString.copyFrom(new byte[] {(byte) 0x0F})) + .build(), + change( + REPEAT_PRIMITIVE_DEFAULT_ROW, + "repeated_bytes", + Arrays.asList(new byte[][] {new byte[] {(byte) 0x0F}})) + }); + + data.add( + new Object[] { + "nullable_double_nvs", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableDouble(DoubleValue.newBuilder().build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_double", 0.0) + }); + data.add( + new Object[] { + "nullable_double", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableDouble(DoubleValue.newBuilder().setValue(Double.MAX_VALUE).build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_double", Double.MAX_VALUE) + }); + data.add( + new Object[] { + "nullable_float_nvs", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableFloat(FloatValue.newBuilder().build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_float", (float) 0) + }); + data.add( + new Object[] { + "nullable_float", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableFloat(FloatValue.newBuilder().setValue(Float.MAX_VALUE).build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_float", Float.MAX_VALUE) + }); + data.add( + new Object[] { + "nullable_int32", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableInt32(Int32Value.newBuilder().setValue(Integer.MAX_VALUE).build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_int32", Integer.MAX_VALUE) + }); + data.add( + new Object[] { + "nullable_int32_nvs", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableInt32(Int32Value.newBuilder().build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_int32", 0) + }); + data.add( + new Object[] { + "nullable_int64", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableInt64(Int64Value.newBuilder().setValue(Long.MAX_VALUE).build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_int64", Long.MAX_VALUE) + }); + data.add( + new Object[] { + "nullable_int64_nvs", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableInt64(Int64Value.newBuilder().build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_int64", 0L) + }); + data.add( + new Object[] { + "nullable_uint32", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableUint32(UInt32Value.newBuilder().setValue(Integer.MAX_VALUE).build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_uint32", Integer.MAX_VALUE) + }); + data.add( + new Object[] { + "nullable_uint32_nvs", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableUint32(UInt32Value.newBuilder().build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_uint32", 0) + }); + data.add( + new Object[] { + "nullable_uint64", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableUint64(UInt64Value.newBuilder().setValue(Long.MAX_VALUE).build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_uint64", Long.MAX_VALUE) + }); + data.add( + new Object[] { + "nullable_uint64_nvs", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableUint64(UInt64Value.newBuilder().build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_uint64", 0L) + }); + data.add( + new Object[] { + "nullable_bool", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableBool(BoolValue.newBuilder().setValue(Boolean.TRUE).build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_bool", Boolean.TRUE) + }); + data.add( + new Object[] { + "nullable_bool_nvs", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableBool(BoolValue.newBuilder().build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_bool", Boolean.FALSE) + }); + data.add( + new Object[] { + "nullable_string", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableString(StringValue.newBuilder().setValue("bar").build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_string", "bar") + }); + data.add( + new Object[] { + "nullable_string_nvs", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableString(StringValue.newBuilder().build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_string", "") + }); + data.add( + new Object[] { + "nullable_bytes", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableBytes( + BytesValue.newBuilder() + .setValue(ByteString.copyFrom(new byte[] {(byte) 0x0F})) + .build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_bytes", new byte[] {(byte) 0x0F}) + }); + data.add( + new Object[] { + "nullable_bytes_nvs", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableBytes(BytesValue.newBuilder().build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_bytes", new byte[] {}) + }); + data.add( + new Object[] { + "wkt_timestamp", + Proto3SchemaMessages.WktMessage.newBuilder() + .setWktTimestamp( + Timestamp.newBuilder().setSeconds(1558680742).setNanos(123000000).build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "wkt_timestamp", new Instant(1558680742123L)) + }); + data.add( + new Object[] { + "wkt_timestamp_nvs", + Proto3SchemaMessages.WktMessage.newBuilder() + .setWktTimestamp(Timestamp.newBuilder().build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "wkt_timestamp", new Instant(0)) + }); + + data.add( + new Object[] { + "special_enum", + Proto3SchemaMessages.Complex.newBuilder() + .setSpecialEnum(Proto3SchemaMessages.Complex.EnumNested.FOO) + .build(), + change(COMPLEX_DEFAULT_ROW, "special_enum", "FOO") + }); + data.add( + new Object[] { + "repeated_enum", + Proto3SchemaMessages.Complex.newBuilder() + .addRepeatedEnum(Proto3SchemaMessages.Complex.EnumNested.FOO) + .addRepeatedEnum(Proto3SchemaMessages.Complex.EnumNested.BAR) + .build(), + change(COMPLEX_DEFAULT_ROW, "repeated_enum", Arrays.asList("FOO", "BAR")) + }); + data.add( + new Object[] { + "oneof_int32", + Proto3SchemaMessages.Complex.newBuilder().setOneofInt32(42).build(), + change(COMPLEX_DEFAULT_ROW, "oneof_int32", 42) + }); + data.add( + new Object[] { + "oneof_bool", + Proto3SchemaMessages.Complex.newBuilder().setOneofBool(true).build(), + change(COMPLEX_DEFAULT_ROW, "oneof_bool", Boolean.TRUE) + }); + data.add( + new Object[] { + "oneof_string", + Proto3SchemaMessages.Complex.newBuilder().setOneofString("one_of_string").build(), + change(COMPLEX_DEFAULT_ROW, "oneof_string", "one_of_string") + }); + data.add( + new Object[] { + "oneof_primitive", + Proto3SchemaMessages.Complex.newBuilder() + .setOneofPrimitive( + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveInt32(42).build()) + .build(), + change( + COMPLEX_DEFAULT_ROW, + "oneof_primitive", + change(PRIMITIVE_DEFAULT_ROW, "primitive_int32", 42)) + }); + Map mapInt = new HashMap<>(); + mapInt.put("one", 1); + mapInt.put("two", 2); + data.add( + new Object[] { + "map_int", + Proto3SchemaMessages.Complex.newBuilder().putX("one", 1).putX("two", 2).build(), + change(COMPLEX_DEFAULT_ROW, "x", mapInt) + }); + Map mapString = new HashMap<>(); + mapString.put("one", "eno"); + mapString.put("two", "owt"); + data.add( + new Object[] { + "map_int", + Proto3SchemaMessages.Complex.newBuilder().putY("one", "eno").putY("two", "owt").build(), + change(COMPLEX_DEFAULT_ROW, "y", mapString) + }); + Map mapRow = new HashMap<>(); + mapRow.put("one", change(PRIMITIVE_DEFAULT_ROW, "primitive_int32", 1)); + mapRow.put("two", change(PRIMITIVE_DEFAULT_ROW, "primitive_string", "two")); + data.add( + new Object[] { + "map_row", + Proto3SchemaMessages.Complex.newBuilder() + .putZ("one", Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveInt32(1).build()) + .putZ( + "two", + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveString("two").build()) + .build(), + change(COMPLEX_DEFAULT_ROW, "z", mapRow) + }); + data.add( + new Object[] { + "subRow", + Proto3SchemaMessages.Message.newBuilder() + .setMessage( + Proto3SchemaMessages.Primitive.newBuilder() + .setPrimitiveString("we love strings") + .build()) + .build(), + Row.withSchema(MESSAGE_SCHEMA) + .addValue(change(PRIMITIVE_DEFAULT_ROW, "primitive_string", "we love strings")) + .addValue(null) + .build() + }); + data.add( + new Object[] { + "subRow+subArrayOfRow", + Proto3SchemaMessages.Message.newBuilder() + .setMessage(Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveInt32(42).build()) + .addRepeatedMessage( + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveInt32(69).build()) + .addRepeatedMessage( + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveInt32(70).build()) + .addRepeatedMessage( + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveInt32(71).build()) + .build(), + Row.withSchema(MESSAGE_SCHEMA) + .addValue(change(PRIMITIVE_DEFAULT_ROW, "primitive_int32", 42)) + .addArray( + change(PRIMITIVE_DEFAULT_ROW, "primitive_int32", 69), + change(PRIMITIVE_DEFAULT_ROW, "primitive_int32", 70), + change(PRIMITIVE_DEFAULT_ROW, "primitive_int32", 71)) + .build() + }); + return data; + } + + private static Row change(Row row, Object field, Object value) { + int index = -1; + List fields = row.getSchema().getFields(); + for (int i = 0; i < fields.size(); i++) { + Schema.Field f = fields.get(i); + if (f.getName().equals(field)) { + index = i; + break; + } + } + + Object[] objects = row.getValues().toArray(); + objects[index] = value; + return Row.withSchema(row.getSchema()).addValues(objects).build(); + } + + private void setup() { + ProtoSchemaProvider protoSchemaProvider = new ProtoSchemaProvider(); + TypeDescriptor typeDescriptor = TypeDescriptor.of(this.proto.getClass()); + + toRowFunction = protoSchemaProvider.toRowFunction(typeDescriptor); + fromRowFunction = protoSchemaProvider.fromRowFunction(typeDescriptor); + } + + private void setupForDynamicMessage() { + ProtoDomain domain = ProtoDomain.buildFrom(proto.getDescriptorForType()); + ProtoSchema protoSchema = + ProtoSchema.newBuilder(domain).forDescriptor(proto.getDescriptorForType()); + + toRowFunction = protoSchema.getSchemaCoder().getToRowFunction(); + fromRowFunction = protoSchema.getSchemaCoder().getFromRowFunction(); + } + + @Test + public void testRowAndBack() { + setup(); + Row row = toRowFunction.apply(this.proto); + Message message = fromRowFunction.apply(row); + assertEquals(proto, message); + } + + @Test + public void testToRow() { + setup(); + Row row = toRowFunction.apply(this.proto); + assertEquals(rowObject, row); + } + + @Test + public void testFromRow() { + setup(); + Message message = fromRowFunction.apply(this.rowObject); + assertEquals(this.proto, message); + } + + @Test + public void testToRowFromDynamicMessage() { + setupForDynamicMessage(); + Row row = toRowFunction.apply(DynamicMessage.newBuilder(this.proto).build()); + assertEquals(rowObject, row); + } + + @Test + public void testFromRowToDynamicMessage() { + setupForDynamicMessage(); + Message message = fromRowFunction.apply(this.rowObject); + assertEquals(this.proto, message); + } +} diff --git a/sdks/java/extensions/protobuf/src/test/proto/proto3_schema_messages.proto b/sdks/java/extensions/protobuf/src/test/proto/proto3_schema_messages.proto new file mode 100644 index 000000000000..68eaa1b2c7bc --- /dev/null +++ b/sdks/java/extensions/protobuf/src/test/proto/proto3_schema_messages.proto @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Protocol Buffer messages used for testing Proto3 Schema implementation. + */ + +syntax = "proto3"; + +package proto3_schema_messages; + +import "google/protobuf/timestamp.proto"; +import "google/protobuf/wrappers.proto"; +import "google/protobuf/descriptor.proto"; + +option java_package = "org.apache.beam.sdk.extensions.protobuf"; + +message MessageMeta { + string description = 1; + bool rewrite = 2; +} + +message FieldMeta { + string description = 1; + int32 foobar = 2; +} + +extend google.protobuf.MessageOptions { + MessageMeta message_meta = 66600666; +} + +extend google.protobuf.FieldOptions { + FieldMeta field_meta = 66600666; +} + +message Message { + Primitive message = 3; + repeated Primitive repeated_message = 4; +} + +message Primitive { + double primitive_double = 3; + float primitive_float = 4; + int32 primitive_int32 = 5; + int64 primitive_int64 = 6; + uint32 primitive_uint32 = 7; + uint64 primitive_uint64 = 8; + sint32 primitive_sint32 = 9; + sint64 primitive_sint64 = 10; + fixed32 primitive_fixed32 = 11; + fixed64 primitive_fixed64 = 12; + sfixed32 primitive_sfixed32 = 13; + sfixed64 primitive_sfixed64 = 14; + bool primitive_bool = 15; + string primitive_string = 16; + bytes primitive_bytes = 17; +} + +message RepeatPrimitive { + repeated double repeated_double = 1; + repeated float repeated_float = 2; + repeated int32 repeated_int32 = 3; + repeated int64 repeated_int64 = 4; + repeated uint32 repeated_uint32 = 5; + repeated uint64 repeated_uint64 = 6; + repeated sint32 repeated_sint32 = 7; + repeated sint64 repeated_sint64 = 8; + repeated fixed32 repeated_fixed32 = 9; + repeated fixed64 repeated_fixed64 = 10; + repeated sfixed32 repeated_sfixed32 = 11; + repeated sfixed64 repeated_sfixed64 = 12; + repeated bool repeated_bool = 13; + repeated string repeated_string = 14; + repeated bytes repeated_bytes = 15; +} + +message Complex { + enum EnumNested { + UNKNOWN = 0; + FOO = 1; + BAR = 2; + } + + EnumNested special_enum = 3; + repeated EnumNested repeated_enum = 4; + + oneof special_oneof { + int32 oneof_int32 = 5; + bool oneof_bool = 6; + string oneof_string = 7; + Primitive oneof_primitive = 8; + } + + map x = 9; + map y = 10; + map z = 11; + + oneof second_oneof { + int64 oneof_int64 = 12; + double oneof_double = 13; + } +} + +message WktMessage { + google.protobuf.DoubleValue nullable_double = 1; + google.protobuf.FloatValue nullable_float = 2; + google.protobuf.Int32Value nullable_int32 = 3; + google.protobuf.Int64Value nullable_int64 = 4; + google.protobuf.UInt32Value nullable_uint32 = 5; + google.protobuf.UInt64Value nullable_uint64 = 6; + google.protobuf.BoolValue nullable_bool = 13; + google.protobuf.StringValue nullable_string = 14; + google.protobuf.BytesValue nullable_bytes = 15; + google.protobuf.Timestamp wkt_timestamp = 16; +} + +message LatLng { + double latitude = 1; + double longitude = 2; +} + +message LogicalTypes { + LatLng gps = 1; +} + +message MessageWithMeta { + option (proto3_schema_messages.message_meta).description = "Cool field"; + option (proto3_schema_messages.message_meta).rewrite = true; + + string field_with_description = 1 [(proto3_schema_messages.field_meta).description = "Cool field"]; + string field_with_foobar = 2 [(proto3_schema_messages.field_meta).foobar = 42]; + string field_with_deprecation = 3 [deprecated = true]; +} + diff --git a/sdks/java/extensions/protobuf/src/test/resources/README.md b/sdks/java/extensions/protobuf/src/test/resources/README.md new file mode 100644 index 000000000000..79083f5142b0 --- /dev/null +++ b/sdks/java/extensions/protobuf/src/test/resources/README.md @@ -0,0 +1,34 @@ + + +This recreates the proto descriptor set included in this resource directory. + +```bash +export PROTO_INCLUDE= +``` +Execute the following command to create the pb files, in the beam root folder: + +```bash +protoc \ + -Isdks/java/extensions/protobuf/src/test/resources/ \ + -I$PROTO_INCLUDE \ + --descriptor_set_out=sdks/java/extensions/protobuf/src/test/resources/org/apache/beam/sdk/extensions/protobuf/test_option_v1.pb \ + --include_imports \ + sdks/java/extensions/protobuf/src/test/resources/test/option/v1/simple.proto +``` diff --git a/sdks/java/extensions/protobuf/src/test/resources/org/apache/beam/sdk/extensions/protobuf/test_option_v1.pb b/sdks/java/extensions/protobuf/src/test/resources/org/apache/beam/sdk/extensions/protobuf/test_option_v1.pb new file mode 100644 index 000000000000..4e97ad02a15b Binary files /dev/null and b/sdks/java/extensions/protobuf/src/test/resources/org/apache/beam/sdk/extensions/protobuf/test_option_v1.pb differ diff --git a/sdks/java/extensions/protobuf/src/test/resources/test/option/v1/option.proto b/sdks/java/extensions/protobuf/src/test/resources/test/option/v1/option.proto new file mode 100644 index 000000000000..ca40119dce3f --- /dev/null +++ b/sdks/java/extensions/protobuf/src/test/resources/test/option/v1/option.proto @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package test.option.v1; + +import "google/protobuf/descriptor.proto"; + +extend google.protobuf.FileOptions { + double fileoption_double = 66666700; + float fileoption_float = 66666701; + int32 fileoption_int32 = 66666702; + int64 fileoption_int64 = 66666703; + uint32 fileoption_uint32 = 66666704; + uint64 fileoption_uint64 = 66666705; + sint32 fileoption_sint32 = 66666706; + sint64 fileoption_sint64 = 66666707; + fixed32 fileoption_fixed32 = 66666708; + fixed64 fileoption_fixed64 = 66666709; + sfixed32 fileoption_sfixed32 = 66666710; + sfixed64 fileoption_sfixed64 = 66666711; + bool fileoption_bool = 66666712; + string fileoption_string = 66666713; + bytes fileoption_bytes = 66666714; + OptionMessage fileoption_message = 66666715; + OptionEnum fileoption_enum = 66666716; +} + +extend google.protobuf.MessageOptions { + double messageoption_double = 66666700; + float messageoption_float = 66666701; + int32 messageoption_int32 = 66666702; + int64 messageoption_int64 = 66666703; + uint32 messageoption_uint32 = 66666704; + uint64 messageoption_uint64 = 66666705; + sint32 messageoption_sint32 = 66666706; + sint64 messageoption_sint64 = 66666707; + fixed32 messageoption_fixed32 = 66666708; + fixed64 messageoption_fixed64 = 66666709; + sfixed32 messageoption_sfixed32 = 66666710; + sfixed64 messageoption_sfixed64 = 66666711; + bool messageoption_bool = 66666712; + string messageoption_string = 66666713; + bytes messageoption_bytes = 66666714; + OptionMessage messageoption_message = 66666715; + OptionEnum messageoption_enum = 66666716; + + repeated double messageoption_repeated_double = 66666800; + repeated float messageoption_repeated_float = 66666801; + repeated int32 messageoption_repeated_int32 = 66666802; + repeated int64 messageoption_repeated_int64 = 66666803; + repeated uint32 messageoption_repeated_uint32 = 66666804; + repeated uint64 messageoption_repeated_uint64 = 66666805; + repeated sint32 messageoption_repeated_sint32 = 66666806; + repeated sint64 messageoption_repeated_sint64 = 66666807; + repeated fixed32 messageoption_repeated_fixed32 = 66666808; + repeated fixed64 messageoption_repeated_fixed64 = 66666809; + repeated sfixed32 messageoption_repeated_sfixed32 = 66666810; + repeated sfixed64 messageoption_repeated_sfixed64 = 66666811; + repeated bool messageoption_repeated_bool = 66666812; + repeated string messageoption_repeated_string = 66666813; + repeated bytes messageoption_repeated_bytes = 66666814; + repeated OptionMessage messageoption_repeated_message = 66666815; + repeated OptionEnum messageoption_repeated_enum = 66666816; +} + +extend google.protobuf.FieldOptions { + double fieldoption_double = 66666700; + float fieldoption_float = 66666701; + int32 fieldoption_int32 = 66666702; + int64 fieldoption_int64 = 66666703; + uint32 fieldoption_uint32 = 66666704; + uint64 fieldoption_uint64 = 66666705; + sint32 fieldoption_sint32 = 66666706; + sint64 fieldoption_sint64 = 66666707; + fixed32 fieldoption_fixed32 = 66666708; + fixed64 fieldoption_fixed64 = 66666709; + sfixed32 fieldoption_sfixed32 = 66666710; + sfixed64 fieldoption_sfixed64 = 66666711; + bool fieldoption_bool = 66666712; + string fieldoption_string = 66666713; + bytes fieldoption_bytes = 66666714; + OptionMessage fieldoption_message = 66666715; + OptionEnum fieldoption_enum = 66666716; + + repeated double fieldoption_repeated_double = 66666800; + repeated float fieldoption_repeated_float = 66666801; + repeated int32 fieldoption_repeated_int32 = 66666802; + repeated int64 fieldoption_repeated_int64 = 66666803; + repeated uint32 fieldoption_repeated_uint32 = 66666804; + repeated uint64 fieldoption_repeated_uint64 = 66666805; + repeated sint32 fieldoption_repeated_sint32 = 66666806; + repeated sint64 fieldoption_repeated_sint64 = 66666807; + repeated fixed32 fieldoption_repeated_fixed32 = 66666808; + repeated fixed64 fieldoption_repeated_fixed64 = 66666809; + repeated sfixed32 fieldoption_repeated_sfixed32 = 66666810; + repeated sfixed64 fieldoption_repeated_sfixed64 = 66666811; + repeated bool fieldoption_repeated_bool = 66666812; + repeated string fieldoption_repeated_string = 66666813; + repeated bytes fieldoption_repeated_bytes = 66666814; + repeated OptionMessage fieldoption_repeated_message = 66666815; + repeated OptionEnum fieldoption_repeated_enum = 66666816; +} + +enum OptionEnum { + DEFAULT = 0; + ENUM1 = 1; + ENUM2 = 2; +} + +message OptionMessage { + string string = 1; + repeated string repeated_string = 2; + + int32 int32 = 3; + repeated int32 repeated_int32 = 4; + + int64 int64 = 5; + + OptionEnum test_enum = 6; +} \ No newline at end of file diff --git a/sdks/java/extensions/protobuf/src/test/resources/test/option/v1/simple.proto b/sdks/java/extensions/protobuf/src/test/resources/test/option/v1/simple.proto new file mode 100644 index 000000000000..1750ddfb3ca5 --- /dev/null +++ b/sdks/java/extensions/protobuf/src/test/resources/test/option/v1/simple.proto @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +import "test/option/v1/option.proto"; + +package test.option.v1; + +message MessageWithOptions { + string test_name = 1; + int32 test_index = 2; + + int32 field_with_fieldoption_double = 700 [(test.option.v1.fieldoption_double) = 100.1]; + int32 field_with_fieldoption_float = 701 [(test.option.v1.fieldoption_float) = 101.2]; + int32 field_with_fieldoption_int32 = 702 [(test.option.v1.fieldoption_int32) = 102]; + int32 field_with_fieldoption_int64 = 703 [(test.option.v1.fieldoption_int64) = 103]; + int32 field_with_fieldoption_uint32 = 704 [(test.option.v1.fieldoption_uint32) = 104]; + int32 field_with_fieldoption_uint64 = 705 [(test.option.v1.fieldoption_uint64) = 105]; + int32 field_with_fieldoption_sint32 = 706 [(test.option.v1.fieldoption_sint32) = 106]; + int32 field_with_fieldoption_sint64 = 707 [(test.option.v1.fieldoption_sint64) = 107]; + int32 field_with_fieldoption_fixed32 = 708; + int32 field_with_fieldoption_fixed64 = 709; + int32 field_with_fieldoption_sfixed32 = 710; + int32 field_with_fieldoption_sfixed64 = 711; + int32 field_with_fieldoption_bool = 712 [(test.option.v1.fieldoption_bool) = true]; + int32 field_with_fieldoption_string = 713 [(test.option.v1.fieldoption_string) = "Oh yeah"]; + int32 field_with_fieldoption_bytes = 714; + int32 field_with_fieldoption_message = 715; + int32 field_with_fieldoption_enum = 716 [(test.option.v1.fieldoption_enum) = ENUM1]; + + int32 field_with_fieldoption_repeated_double = 800; + int32 field_with_fieldoption_repeated_float = 801; + int32 field_with_fieldoption_repeated_int32 = 802; + int32 field_with_fieldoption_repeated_int64 = 803; + int32 field_with_fieldoption_repeated_uint32 = 804; + int32 field_with_fieldoption_repeated_uint64 = 805; + int32 field_with_fieldoption_repeated_sint32 = 806; + int32 field_with_fieldoption_repeated_sint64 = 807; + int32 field_with_fieldoption_repeated_fixed32 = 808; + int32 field_with_fieldoption_repeated_fixed64 = 809; + int32 field_with_fieldoption_repeated_sfixed32 = 810; + int32 field_with_fieldoption_repeated_sfixed64 = 811; + int32 field_with_fieldoption_repeated_bool = 812; + int32 field_with_fieldoption_repeated_string = 813 [(test.option.v1.fieldoption_repeated_string) = "Oh yeah", + (test.option.v1.fieldoption_repeated_string) = "Oh no"]; + int32 field_with_fieldoption_repeated_bytes = 814; + int32 field_with_fieldoption_repeated_message = 815; + int32 field_with_fieldoption_repeated_enum = 816; + +} +