From 10002a5300bc5784821912b6c8d29ff8baabbbb5 Mon Sep 17 00:00:00 2001 From: Alex Van Boxel Date: Sat, 11 May 2019 16:56:20 +0200 Subject: [PATCH 1/2] [BEAM-7274] Implement the Protobuf schema provider The default implementation for Protobuf messages. Supports both generated and dynamic messages. Add the initial release of handling of options, converting it to schema metadata. The implementation is not feature complete, but shows the viability of the feature. --- .../java/org/apache/beam/sdk/values/Row.java | 10 +- .../beam/sdk/values/RowWithGetters.java | 17 +- .../protobuf/ProtoFieldOverlay.java | 525 ++++++++++++++ .../sdk/extensions/protobuf/ProtoSchema.java | 569 +++++++++++++++ .../protobuf/ProtoSchemaProvider.java | 84 +++ .../extensions/protobuf/ProtoSchemaTest.java | 577 +++++++++++++++ .../protobuf/ProtoSchemaValuesTest.java | 670 ++++++++++++++++++ .../test/proto/proto3_schema_messages.proto | 149 ++++ .../protobuf/src/test/resources/README.md | 34 + .../sdk/extensions/protobuf/test_option_v1.pb | Bin 0 -> 18745 bytes .../resources/test/option/v1/option.proto | 137 ++++ .../resources/test/option/v1/simple.proto | 67 ++ 12 files changed, 2833 insertions(+), 6 deletions(-) create mode 100644 sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoFieldOverlay.java create mode 100644 sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchema.java create mode 100644 sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaProvider.java create mode 100644 sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTest.java create mode 100644 sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaValuesTest.java create mode 100644 sdks/java/extensions/protobuf/src/test/proto/proto3_schema_messages.proto create mode 100644 sdks/java/extensions/protobuf/src/test/resources/README.md 
create mode 100644 sdks/java/extensions/protobuf/src/test/resources/org/apache/beam/sdk/extensions/protobuf/test_option_v1.pb create mode 100644 sdks/java/extensions/protobuf/src/test/resources/test/option/v1/option.proto create mode 100644 sdks/java/extensions/protobuf/src/test/resources/test/option/v1/simple.proto diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java index 48c708c75fab..8e5a7794bc8e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java @@ -499,6 +499,7 @@ public static class Builder { @Nullable private Factory> fieldValueGetterFactory; @Nullable private Object getterTarget; private Schema schema; + private boolean collectionHandledByGetter = false; Builder(Schema schema) { this.schema = schema; @@ -554,6 +555,12 @@ public Builder withFieldValueGetters( return this; } + /** The FieldValueGetters will handle the conversion for Arrays, Maps and Rows. 
*/ + public Builder withFieldValueGettersHandleCollections(boolean collectionHandledByGetter) { + this.collectionHandledByGetter = collectionHandledByGetter; + return this; + } + private List verify(Schema schema, List values) { List verifiedValues = Lists.newArrayListWithCapacity(values.size()); if (schema.getFieldCount() != values.size()) { @@ -754,7 +761,8 @@ public Row build() { return new RowWithStorage(schema, storageValues); } else if (fieldValueGetterFactory != null) { checkState(getterTarget != null, "getters require withGetterTarget."); - return new RowWithGetters(schema, fieldValueGetterFactory, getterTarget); + return new RowWithGetters( + schema, fieldValueGetterFactory, getterTarget, collectionHandledByGetter); } else { return new RowWithStorage(schema, Collections.emptyList()); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java index e50f818c1f67..c0fe11f331f2 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java @@ -46,12 +46,18 @@ public class RowWithGetters extends Row { private final Map cachedLists = Maps.newHashMap(); private final Map cachedMaps = Maps.newHashMap(); + private final boolean collectionHandledByGetter; + RowWithGetters( - Schema schema, Factory> getterFactory, Object getterTarget) { + Schema schema, + Factory> getterFactory, + Object getterTarget, + boolean collectionHandledByGetter) { super(schema); this.fieldValueGetterFactory = getterFactory; this.getterTarget = getterTarget; this.getters = fieldValueGetterFactory.create(getterTarget.getClass(), schema); + this.collectionHandledByGetter = collectionHandledByGetter; } @Nullable @@ -87,15 +93,16 @@ private List getListValue(FieldType elementType, Object fieldValue) { @SuppressWarnings({"TypeParameterUnusedInFormals", "unchecked"}) private T 
getValue(FieldType type, Object fieldValue, @Nullable Integer cacheKey) { - if (type.getTypeName().equals(TypeName.ROW)) { - return (T) new RowWithGetters(type.getRowSchema(), fieldValueGetterFactory, fieldValue); - } else if (type.getTypeName().equals(TypeName.ARRAY)) { + if (type.getTypeName().equals(TypeName.ROW) && !collectionHandledByGetter) { + return (T) + new RowWithGetters(type.getRowSchema(), fieldValueGetterFactory, fieldValue, false); + } else if (type.getTypeName().equals(TypeName.ARRAY) && !collectionHandledByGetter) { return cacheKey != null ? (T) cachedLists.computeIfAbsent( cacheKey, i -> getListValue(type.getCollectionElementType(), fieldValue)) : (T) getListValue(type.getCollectionElementType(), fieldValue); - } else if (type.getTypeName().equals(TypeName.MAP)) { + } else if (type.getTypeName().equals(TypeName.MAP) && !collectionHandledByGetter) { Map map = (Map) fieldValue; return cacheKey != null ? (T) diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoFieldOverlay.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoFieldOverlay.java new file mode 100644 index 000000000000..6cddc2756358 --- /dev/null +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoFieldOverlay.java @@ -0,0 +1,525 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.protobuf; + +import com.google.protobuf.ByteString; +import com.google.protobuf.Descriptors; +import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.DynamicMessage; +import com.google.protobuf.Message; +import com.google.protobuf.Timestamp; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.schemas.FieldValueGetter; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaCoder; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.joda.time.Instant; + +/** + * Protobuf ProtoFieldOverlay is the interface that each implementation needs to implement to handle + * a specific field types. + */ +@Experimental(Experimental.Kind.SCHEMAS) +public interface ProtoFieldOverlay extends FieldValueGetter { + + ValueT convertGetObject(FieldDescriptor fieldDescriptor, Object object); + + /** Convert the Row field and set it on the overlayed field of the message. */ + void set(Message.Builder object, ValueT value); + + Object convertSetObject(FieldDescriptor fieldDescriptor, Object value); + + /** Return the Beam Schema Field of this overlayed field. 
*/ + Schema.Field getSchemaField(); + + abstract class ProtoFieldOverlayBase implements ProtoFieldOverlay { + + protected int number; + + private Schema.Field field; + + FieldDescriptor getFieldDescriptor(Message message) { + return message.getDescriptorForType().findFieldByNumber(number); + } + + FieldDescriptor getFieldDescriptor(Message.Builder message) { + return message.getDescriptorForType().findFieldByNumber(number); + } + + protected void setField(Schema.Field field) { + this.field = field; + } + + ProtoFieldOverlayBase(ProtoSchema protoSchema, FieldDescriptor fieldDescriptor) { + // this.fieldDescriptor = fieldDescriptor; + this.number = fieldDescriptor.getNumber(); + } + + @Override + public String name() { + return field.getName(); + } + + @Override + public Schema.Field getSchemaField() { + return field; + } + } + + /** Overlay for Protobuf primitive types. Primitive values are just passed through. */ + class PrimitiveOverlay extends ProtoFieldOverlayBase { + PrimitiveOverlay(ProtoSchema protoSchema, FieldDescriptor fieldDescriptor) { + // this.fieldDescriptor = fieldDescriptor; + super(protoSchema, fieldDescriptor); + setField( + Schema.Field.of( + fieldDescriptor.getName(), + ProtoSchema.convertType(fieldDescriptor.getType()) + .withMetadata(protoSchema.convertOptions(fieldDescriptor)))); + } + + @Override + public Object get(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + return convertGetObject(fieldDescriptor, message.getField(fieldDescriptor)); + } + + @Override + public Object convertGetObject(FieldDescriptor fieldDescriptor, Object object) { + return object; + } + + @Override + public void set(Message.Builder message, Object value) { + message.setField(getFieldDescriptor(message), value); + } + + @Override + public Object convertSetObject(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } + + /** + * Overlay for Bytes. 
Protobuf Bytes are natively represented as ByteStrings that requires special + * handling for byte[] of size 0. + */ + class BytesOverlay extends PrimitiveOverlay { + BytesOverlay(ProtoSchema protoSchema, FieldDescriptor fieldDescriptor) { + super(protoSchema, fieldDescriptor); + } + + @Override + public Object convertGetObject(FieldDescriptor fieldDescriptor, Object object) { + // return object; + return ((ByteString) object).toByteArray(); + } + + @Override + public void set(Message.Builder message, Object value) { + if (value != null && ((byte[]) value).length > 0) { + // Protobuf messages BYTES doesn't like empty bytes?! + FieldDescriptor fieldDescriptor = message.getDescriptorForType().findFieldByNumber(number); + message.setField(fieldDescriptor, convertSetObject(fieldDescriptor, value)); + } + } + + @Override + public Object convertSetObject(FieldDescriptor fieldDescriptor, Object value) { + if (value != null) { + return ByteString.copyFrom((byte[]) value); + } + return null; + } + } + + /** + * Overlay handler for the Well Known Type "Wrapper". These wrappers make it possible to have + * nullable primitives. 
+ */ + class WrapperOverlay extends ProtoFieldOverlayBase { + private ProtoFieldOverlay value; + + WrapperOverlay(ProtoSchema protoSchema, FieldDescriptor fieldDescriptor) { + super(protoSchema, fieldDescriptor); + FieldDescriptor valueDescriptor = fieldDescriptor.getMessageType().findFieldByName("value"); + this.value = protoSchema.createFieldLayer(valueDescriptor, false); + setField( + Schema.Field.of( + fieldDescriptor.getName(), value.getSchemaField().getType().withNullable(true))); + } + + @Override + public ValueT get(Message message) { + if (message.hasField(getFieldDescriptor(message))) { + Message wrapper = (Message) message.getField(getFieldDescriptor(message)); + return (ValueT) value.get(wrapper); + } + return null; + } + + @Override + public ValueT convertGetObject(FieldDescriptor fieldDescriptor, Object object) { + return (ValueT) object; + } + + @Override + public void set(Message.Builder message, ValueT value) { + if (value != null) { + DynamicMessage.Builder builder = + DynamicMessage.newBuilder(getFieldDescriptor(message).getMessageType()); + this.value.set(builder, value); + message.setField(getFieldDescriptor(message), builder.build()); + } + } + + @Override + public Object convertSetObject(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } + + /** + * Overlay handler for the Well Known Type "Timestamp". This wrappers converts from a single Row + * DATETIME and a protobuf "Timestamp" messsage. 
+ */ + class TimestampOverlay extends ProtoFieldOverlayBase { + TimestampOverlay(ProtoSchema protoSchema, FieldDescriptor fieldDescriptor) { + super(protoSchema, fieldDescriptor); + setField( + Schema.Field.of( + fieldDescriptor.getName(), + Schema.FieldType.DATETIME.withMetadata( + protoSchema.convertOptions(fieldDescriptor))) + .withNullable(true)); + } + + @Override + public Instant get(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + if (message.hasField(fieldDescriptor)) { + Message wrapper = (Message) message.getField(fieldDescriptor); + return convertGetObject(fieldDescriptor, wrapper); + } + return null; + } + + @Override + public Instant convertGetObject(FieldDescriptor fieldDescriptor, Object object) { + Message timestamp = (Message) object; + Descriptors.Descriptor timestampFieldDescriptor = timestamp.getDescriptorForType(); + return new Instant( + (Long) timestamp.getField(timestampFieldDescriptor.findFieldByName("seconds")) * 1000 + + (Integer) timestamp.getField(timestampFieldDescriptor.findFieldByName("nanos")) + / 1000000); + } + + @Override + public void set(Message.Builder message, Instant value) { + if (value != null) { + long totalMillis = value.getMillis(); + long seconds = totalMillis / 1000; + int ns = (int) (totalMillis % 1000 * 1000000); + Timestamp timestamp = Timestamp.newBuilder().setSeconds(seconds).setNanos(ns).build(); + message.setField(getFieldDescriptor(message), timestamp); + } + } + + @Override + public Object convertSetObject(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } + + /** This overlay converts a nested Message into a nested Row. 
*/ + class MessageOverlay extends ProtoFieldOverlayBase { + private final SerializableFunction toRowFunction; + private final SerializableFunction fromRowFunction; + + MessageOverlay(ProtoSchema rootProtoSchema, FieldDescriptor fieldDescriptor) { + super(rootProtoSchema, fieldDescriptor); + + ProtoSchema protoSchema = + ProtoSchema.newBuilder(rootProtoSchema).forDescriptor(fieldDescriptor.getMessageType()); + SchemaCoder schemaCoder = protoSchema.getSchemaCoder(); + toRowFunction = schemaCoder.getToRowFunction(); + fromRowFunction = schemaCoder.getFromRowFunction(); + setField( + Schema.Field.of( + fieldDescriptor.getName(), + Schema.FieldType.row(protoSchema.getSchema()) + .withMetadata(protoSchema.convertOptions(fieldDescriptor)) + .withNullable(true))); + } + + @Override + public Object get(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + if (message.hasField(fieldDescriptor)) { + return convertGetObject(fieldDescriptor, message.getField(fieldDescriptor)); + } + return null; + } + + @Override + public Object convertGetObject(FieldDescriptor fieldDescriptor, Object object) { + return toRowFunction.apply(object); + } + + @Override + public void set(Message.Builder message, Object value) { + if (value != null) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + message.setField(fieldDescriptor, convertSetObject(fieldDescriptor, value)); + } + } + + @Override + public Object convertSetObject(FieldDescriptor fieldDescriptor, Object value) { + return fromRowFunction.apply(value); + } + } + + /** + * Proto has a well defined way of storing maps, by having a Message with two fields, named "key" + * and "value" in a repeatable field. This overlay translates between Row.map and the Protobuf + * map. 
+ */ + class MapOverlay extends ProtoFieldOverlayBase { + private ProtoFieldOverlay key; + private ProtoFieldOverlay value; + + MapOverlay(ProtoSchema protoSchema, FieldDescriptor fieldDescriptor) { + super(protoSchema, fieldDescriptor); + key = + protoSchema.createFieldLayer( + fieldDescriptor.getMessageType().findFieldByName("key"), false); + value = + protoSchema.createFieldLayer( + fieldDescriptor.getMessageType().findFieldByName("value"), false); + setField( + Schema.Field.of( + fieldDescriptor.getName(), + Schema.FieldType.map( + key.getSchemaField().getType(), + value + .getSchemaField() + .getType() + .withMetadata(protoSchema.convertOptions(fieldDescriptor))) + .withNullable(true))); + } + + @Override + public Map get(Message message) { + List list = (List) message.getField(getFieldDescriptor(message)); + if (list.size() == 0) { + return null; + } + Map rowMap = new HashMap(); + list.forEach( + entry -> { + Message entryMessage = (Message) entry; + Descriptors.Descriptor entryDescriptor = entryMessage.getDescriptorForType(); + FieldDescriptor keyFieldDescriptor = entryDescriptor.findFieldByName("key"); + FieldDescriptor valueFieldDescriptor = entryDescriptor.findFieldByName("value"); + rowMap.put( + key.convertGetObject(keyFieldDescriptor, entryMessage.getField(keyFieldDescriptor)), + this.value.convertGetObject( + valueFieldDescriptor, entryMessage.getField(valueFieldDescriptor))); + }); + return rowMap; + } + + @Override + public Map convertGetObject(FieldDescriptor fieldDescriptor, Object object) { + throw new RuntimeException("?"); + } + + @Override + public void set(Message.Builder message, Map map) { + if (map != null) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + List messageMap = new ArrayList(); + map.forEach( + (k, v) -> { + DynamicMessage.Builder builder = + DynamicMessage.newBuilder(fieldDescriptor.getMessageType()); + FieldDescriptor keyFieldDescriptor = + fieldDescriptor.getMessageType().findFieldByName("key"); + 
builder.setField( + keyFieldDescriptor, this.key.convertSetObject(keyFieldDescriptor, k)); + FieldDescriptor valueFieldDescriptor = + fieldDescriptor.getMessageType().findFieldByName("value"); + builder.setField( + valueFieldDescriptor, value.convertSetObject(valueFieldDescriptor, v)); + messageMap.add(builder.build()); + }); + message.setField(fieldDescriptor, messageMap); + } + } + + @Override + public Object convertSetObject(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } + + /** + * This overlay handles repeatable fields. It handles the Array conversion, but delegates the + * conversion of the individual elements to an embedded overlay. + */ + class ArrayOverlay extends ProtoFieldOverlayBase { + private ProtoFieldOverlay element; + + ArrayOverlay(ProtoSchema protoSchema, FieldDescriptor fieldDescriptor) { + super(protoSchema, fieldDescriptor); + this.element = protoSchema.createFieldLayer(fieldDescriptor, false); + setField( + Schema.Field.of( + fieldDescriptor.getName(), + Schema.FieldType.array( + element + .getSchemaField() + .getType() + .withMetadata(protoSchema.convertOptions(fieldDescriptor))) + .withNullable(true))); + } + + @Override + public List get(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + List list = (List) message.getField(fieldDescriptor); + if (list.size() == 0) { + return null; + } + List arrayList = new ArrayList<>(); + list.forEach( + entry -> { + arrayList.add(element.convertGetObject(fieldDescriptor, entry)); + }); + return arrayList; + } + + @Override + public List convertGetObject(FieldDescriptor fieldDescriptor, Object object) { + throw new RuntimeException("?"); + } + + @Override + public void set(Message.Builder message, List list) { + if (list != null) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + List targetList = new ArrayList(); + list.forEach( + (e) -> { + targetList.add(element.convertSetObject(fieldDescriptor, e)); + }); + 
message.setField(fieldDescriptor, targetList); + } + } + + @Override + public Object convertSetObject(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } + + /** Enum overlay handles the conversion between a string and a ProtoBuf Enum. */ + class EnumOverlay extends ProtoFieldOverlayBase { + + EnumOverlay(ProtoSchema protoSchema, FieldDescriptor fieldDescriptor) { + super(protoSchema, fieldDescriptor); + setField( + Schema.Field.of( + fieldDescriptor.getName(), + Schema.FieldType.STRING.withMetadata(protoSchema.convertOptions(fieldDescriptor)))); + } + + @Override + public Object get(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + return convertGetObject(fieldDescriptor, message.getField(fieldDescriptor)); + } + + @Override + public Object convertGetObject(FieldDescriptor fieldDescriptor, Object in) { + return in.toString(); + } + + @Override + public void set(Message.Builder message, Object value) { + // builder.setField(fieldDescriptor, + // convertSetObject(row.getString(fieldDescriptor.getName()))); + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + message.setField(fieldDescriptor, convertSetObject(fieldDescriptor, value)); + } + + @Override + public Object convertSetObject(FieldDescriptor fieldDescriptor, Object value) { + Descriptors.EnumDescriptor enumType = fieldDescriptor.getEnumType(); + return enumType.findValueByName(value.toString()); + } + } + + /** + * This overlay handles nullable fields. If a primitive field needs to be nullable this overlay is + * wrapped around the original overlay. 
+ */ + class NullableOverlay extends ProtoFieldOverlayBase { + + private ProtoFieldOverlay fieldOverlay; + + NullableOverlay( + ProtoSchema protoSchema, + FieldDescriptor fieldDescriptor, + ProtoFieldOverlay fieldOverlay) { + super(protoSchema, fieldDescriptor); + this.fieldOverlay = fieldOverlay; + setField(fieldOverlay.getSchemaField().withNullable(true)); + } + + @Override + public Object get(Message message) { + if (message.hasField(getFieldDescriptor(message))) { + return fieldOverlay.get(message); + } + return null; + } + + @Override + public Object convertGetObject(FieldDescriptor fieldDescriptor, Object object) { + throw new RuntimeException("Value conversion should never be allowed in nullable fields"); + } + + @Override + public void set(Message.Builder message, Object value) { + if (value != null) { + fieldOverlay.set(message, value); + } + } + + @Override + public Object convertSetObject(FieldDescriptor fieldDescriptor, Object value) { + throw new RuntimeException("Value conversion should never be allowed in nullable fields"); + } + } +} diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchema.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchema.java new file mode 100644 index 000000000000..b05be3fddc96 --- /dev/null +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchema.java @@ -0,0 +1,569 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.protobuf; + +import com.google.protobuf.DescriptorProtos; +import com.google.protobuf.Descriptors; +import com.google.protobuf.DynamicMessage; +import com.google.protobuf.Message; +import com.google.protobuf.UnknownFieldSet; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.math.BigInteger; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.Collectors; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.schemas.Factory; +import org.apache.beam.sdk.schemas.FieldValueGetter; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaCoder; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; + +/** + * ProtoSchema is a top level anchor point. It makes sure it can recreate the complete schema and + * overlay with just the Message raw type or if it's a DynamicMessage with the serialised + * Descriptor. + * + *

ProtoDomain is an integral part of a ProtoSchema; it contains all the information needed to + * interpret and reconstruct messages. + * + *

    + *
  • Protobuf oneOf fields are mapped to nullable fields and flattened into the parent row. + *
  • Protobuf primitives are mapped to their nullable counterparts. + *
  • Protobuf maps are mapped to nullable maps, where empty maps are mapped to the null value. + *
  • Protobuf repeatables are mapped to nullable arrays, where empty arrays are mapped to the + * null value. + *
  • Protobuf enums are mapped to non-nullable string values. + *
  • Enums map to their string representation. + *
+ * + *

Protobuf Well Known Types are handled by the Beam Schema system. Beam knows of the following + * Well Known Types: + * + *

    + *
  • google.protobuf.Timestamp maps to a nullable Field.DATETIME. + *
  • google.protobuf.StringValue maps to a nullable Field.STRING. + *
  • google.protobuf.DoubleValue maps to a nullable Field.DOUBLE. + *
  • google.protobuf.FloatValue maps to a nullable Field.FLOAT. + *
  • google.protobuf.BytesValue maps to a nullable Field.BYTES. + *
  • google.protobuf.BoolValue maps to a nullable Field.BOOL. + *
  • google.protobuf.Int64Value maps to a nullable Field.INT64. + *
  • google.protobuf.Int32Value maps to a nullable Field.INT32. + *
  • google.protobuf.UInt64Value maps to a nullable Field.INT64. + *
  • google.protobuf.UInt32Value maps to a nullable Field.INT32. + *
+ */ +@Experimental(Experimental.Kind.SCHEMAS) +public class ProtoSchema implements Serializable { + public static final long serialVersionUID = 1L; + private static final ProtoDomain STATIC_COMPILED_DOMAIN = new ProtoDomain(); + private static Map globalSchemaCache = new HashMap<>(); + private final Class rawType; + private final Map typeMapping; + private final ProtoDomain domain; + private transient Descriptors.Descriptor descriptor; + private transient SchemaCoder schemaCoder; + private transient Method fnNewBuilder; + private transient ArrayList getters; + + private ProtoSchema( + Class rawType, + Descriptors.Descriptor descriptor, + ProtoDomain domain, + Map overlayClasses) { + this.rawType = rawType; + this.descriptor = descriptor; + this.typeMapping = overlayClasses; + this.domain = domain; + init(); + } + + /** + * Create a new ProtoSchema Builder with the static compiled proto domain. This domain references + * only statically compiled Java Protobuf messages. + */ + public static Builder newBuilder() { + return new Builder(STATIC_COMPILED_DOMAIN); + } + + /** + * Create a new ProtoSchema Builder with a specific proto domain. It does not contain any messages + * of the static domain. A Domain is used for grouping different messages that belong together. + * Creating different schema builders with the same domain is safe. The resulting Protobuf + * messages created from the same domain with be equal. 
+ */ + public static Builder newBuilder(ProtoDomain protoDomain) { + return new Builder(protoDomain); + } + + static Builder newBuilder(ProtoSchema protoSchema) { + return new Builder(protoSchema.domain).addTypeMapping(protoSchema.typeMapping); + } + + static ProtoSchema fromSchema(Schema schema) { + return globalSchemaCache.get(schema.getUUID()); + } + + static Schema.FieldType convertType(Descriptors.FieldDescriptor.Type type) { + switch (type) { + case DOUBLE: + return Schema.FieldType.DOUBLE; + case FLOAT: + return Schema.FieldType.FLOAT; + case INT64: + case UINT64: + case SINT64: + case FIXED64: + case SFIXED64: + return Schema.FieldType.INT64; + case INT32: + case FIXED32: + case UINT32: + case SFIXED32: + case SINT32: + return Schema.FieldType.INT32; + case BOOL: + return Schema.FieldType.BOOLEAN; + case STRING: + case ENUM: + return Schema.FieldType.STRING; + case BYTES: + return Schema.FieldType.BYTES; + case MESSAGE: + case GROUP: + break; + } + throw new RuntimeException("Field type not matched."); + } + + Map convertOptions(Descriptors.FieldDescriptor protoField) { + Map metadata = new HashMap<>(); + DescriptorProtos.FieldOptions options = protoField.getOptions(); + options + .getAllFields() + .forEach( + (fd, value) -> { + String name = fd.getFullName(); + if (name.startsWith("google.protobuf.FieldOptions")) { + name = fd.getName(); + } + if (value instanceof Message) { + Message message = (Message) value; + Descriptors.Descriptor descriptorForType = message.getDescriptorForType(); + List fields = descriptorForType.getFields(); + for (Descriptors.FieldDescriptor field : fields) { + metadata.put( + name + "." 
+ field.getName(), + message.getField(field).toString().getBytes(StandardCharsets.UTF_8)); + } + } else { + metadata.put(name, value.toString().getBytes(StandardCharsets.UTF_8)); + } + }); + + options + .getUnknownFields() + .asMap() + .forEach( + (ix, ufs) -> { + Descriptors.FieldDescriptor fieldOptionById = domain.getFieldOptionById(ix); + if (fieldOptionById != null) { + String name = fieldOptionById.getFullName(); + decodeUnknownOptionValue(metadata, name, fieldOptionById, ufs); + } + }); + return metadata; + } + + private void decodeUnknownOptionValue( + Map metadata, + String name, + Descriptors.FieldDescriptor fieldDescriptor, + UnknownFieldSet.Field value) { + + switch (fieldDescriptor.getType()) { + case MESSAGE: + break; + case FIXED64: + metadata.put( + name, + value.getFixed64List().stream() + .map( + l -> { + if (l >= 0) { + return Long.toString(l); + } else { + return BigInteger.valueOf(l & 0x7FFFFFFFFFFFFFFFL).setBit(63).toString(); + } + }) + .collect(Collectors.joining("\n")) + .getBytes(StandardCharsets.UTF_8)); + break; + case FIXED32: + break; + case BOOL: + metadata.put( + name, + value.getVarintList().stream() + .map(l -> Boolean.valueOf(l.intValue() > 0).toString()) + .collect(Collectors.joining("\n")) + .getBytes(StandardCharsets.UTF_8)); + break; + case ENUM: + metadata.put( + name, + value.getVarintList().stream() + .map(l -> fieldDescriptor.getEnumType().findValueByNumber(l.intValue()).getName()) + .collect(Collectors.joining("\n")) + .getBytes(StandardCharsets.UTF_8)); + break; + case STRING: + metadata.put( + name, + value.getLengthDelimitedList().stream() + .map(l -> (l.toString(StandardCharsets.UTF_8))) + .collect(Collectors.joining("\n")) + .getBytes(StandardCharsets.UTF_8)); + break; + case INT32: + case INT64: + case SINT32: + case SINT64: + case UINT32: + metadata.put( + name, + value.getVarintList().stream() + .map(l -> l.toString()) + .collect(Collectors.joining("\n")) + .getBytes(StandardCharsets.UTF_8)); + break; + case 
UINT64: + metadata.put( + name, + value.getVarintList().stream() + .map( + l -> { + if (l >= 0) { + return Long.toString(l); + } else { + return BigInteger.valueOf(l & 0x7FFFFFFFFFFFFFFFL).setBit(63).toString(); + } + }) + .collect(Collectors.joining("\n")) + .getBytes(StandardCharsets.UTF_8)); + break; + case DOUBLE: + metadata.put( + name, + value.getFixed64List().stream() + .map(l -> String.valueOf(Double.longBitsToDouble(l))) + .collect(Collectors.joining("\n")) + .getBytes(StandardCharsets.UTF_8)); + break; + case FLOAT: + metadata.put( + name, + value.getFixed32List().stream() + .map(l -> String.valueOf(Float.intBitsToFloat(l))) + .collect(Collectors.joining("\n")) + .getBytes(StandardCharsets.UTF_8)); + break; + case BYTES: + if (value.getLengthDelimitedList().size() > 0) { + metadata.put(name, value.getLengthDelimitedList().get(0).toByteArray()); + } + break; + case SFIXED32: + break; + case SFIXED64: + break; + case GROUP: + break; + default: + throw new IllegalStateException( + "Conversion of Unknown Field for type " + + fieldDescriptor.getType().toString() + + " not implemented"); + } + } + + private static boolean isMap(Descriptors.FieldDescriptor protoField) { + return protoField.getType() == Descriptors.FieldDescriptor.Type.MESSAGE + && protoField.getMessageType().getFullName().endsWith("Entry") + && (protoField.getMessageType().findFieldByName("key") != null) + && (protoField.getMessageType().findFieldByName("value") != null); + } + + ProtoFieldOverlay createFieldLayer(Descriptors.FieldDescriptor protoField, boolean nullable) { + Descriptors.FieldDescriptor.Type fieldDescriptor = protoField.getType(); + ProtoFieldOverlay fieldOverlay; + switch (fieldDescriptor) { + case DOUBLE: + case FLOAT: + case INT64: + case UINT64: + case SINT64: + case FIXED64: + case SFIXED64: + case INT32: + case FIXED32: + case UINT32: + case SFIXED32: + case SINT32: + case BOOL: + case STRING: + fieldOverlay = new ProtoFieldOverlay.PrimitiveOverlay(this, protoField); + 
break; + case BYTES: + fieldOverlay = new ProtoFieldOverlay.BytesOverlay(this, protoField); + break; + case ENUM: + fieldOverlay = new ProtoFieldOverlay.EnumOverlay(this, protoField); + break; + case MESSAGE: + String fullName = protoField.getMessageType().getFullName(); + if (typeMapping.containsKey(fullName)) { + Class aClass = typeMapping.get(fullName); + try { + Constructor constructor = aClass.getConstructor(Descriptors.FieldDescriptor.class); + return (ProtoFieldOverlay) constructor.newInstance(protoField); + } catch (NoSuchMethodException e) { + throw new RuntimeException("Unable to find constructor for Overlay mapper."); + } catch (IllegalAccessException | InstantiationException | InvocationTargetException e) { + throw new RuntimeException("Unable to invoke Overlay mapper."); + } + } + switch (fullName) { + case "google.protobuf.Timestamp": + return new ProtoFieldOverlay.TimestampOverlay(this, protoField); + case "google.protobuf.StringValue": + case "google.protobuf.DoubleValue": + case "google.protobuf.FloatValue": + case "google.protobuf.BoolValue": + case "google.protobuf.Int64Value": + case "google.protobuf.Int32Value": + case "google.protobuf.UInt64Value": + case "google.protobuf.UInt32Value": + case "google.protobuf.BytesValue": + return new ProtoFieldOverlay.WrapperOverlay(this, protoField); + case "google.protobuf.Duration": + default: + if (isMap(protoField)) { + return new ProtoFieldOverlay.MapOverlay(this, protoField); + } else { + return new ProtoFieldOverlay.MessageOverlay(this, protoField); + } + } + case GROUP: + default: + throw new RuntimeException("Field type not matched."); + } + if (nullable) { + return new ProtoFieldOverlay.NullableOverlay(this, protoField, fieldOverlay); + } + return fieldOverlay; + } + + private ArrayList createFieldLayer(Descriptors.Descriptor descriptor) { + // Oneof fields are nullable, even as they are primitive or enums + List oneofMap = + descriptor.getOneofs().stream() + .flatMap(oneofDescriptor -> 
oneofDescriptor.getFields().stream()) + .collect(Collectors.toList()); + + ArrayList fieldOverlays = new ArrayList<>(); + Iterator protoFields = descriptor.getFields().iterator(); + for (int i = 0; i < descriptor.getFields().size(); i++) { + Descriptors.FieldDescriptor protoField = protoFields.next(); + if (protoField.isRepeated() && !isMap(protoField)) { + fieldOverlays.add(new ProtoFieldOverlay.ArrayOverlay(this, protoField)); + } else { + fieldOverlays.add(createFieldLayer(protoField, oneofMap.contains(protoField))); + } + } + return fieldOverlays; + } + + private void init() { + this.getters = createFieldLayer(descriptor); + + Schema.Builder builder = Schema.builder(); + for (ProtoFieldOverlay field : getters) { + builder.addField(field.getSchemaField()); + } + + Schema schema = builder.build(); + schema.setUUID(UUID.randomUUID()); + schemaCoder = + SchemaCoder.of( + schema, + TypeDescriptor.of(rawType), + new MessageToRowFunction(), + new RowToMessageFunction()); + + globalSchemaCache.put(schema.getUUID(), this); + try { + if (DynamicMessage.class.equals(rawType)) { + this.fnNewBuilder = rawType.getMethod("newBuilder", Descriptors.Descriptor.class); + } else { + this.fnNewBuilder = rawType.getMethod("newBuilder"); + } + } catch (NoSuchMethodException e) { + } + } + + public Schema getSchema() { + return this.schemaCoder.getSchema(); + } + + public SchemaCoder getSchemaCoder() { + return schemaCoder; + } + + public SerializableFunction getToRowFunction() { + return schemaCoder.getToRowFunction(); + } + + public SerializableFunction getFromRowFunction() { + return schemaCoder.getFromRowFunction(); + } + + private void writeObject(ObjectOutputStream oos) throws IOException { + oos.defaultWriteObject(); + if (DynamicMessage.class.equals(this.rawType)) { + if (this.descriptor == null) { + throw new RuntimeException("DynamicMessages require provider a Descriptor to the coder."); + } + oos.writeUTF(descriptor.getFullName()); + } + } + + private void 
readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException { + ois.defaultReadObject(); + if (DynamicMessage.class.equals(rawType)) { + descriptor = domain.getDescriptor(ois.readUTF()); + } else { + descriptor = ProtobufUtil.getDescriptorForClass(rawType); + } + init(); + } + + public ProtoDomain getDomain() { + return domain; + } + + public static class Builder implements Serializable { + + private ProtoDomain domain; + private Map mappings = new HashMap<>(); + + public Builder(ProtoDomain domain) { + this.domain = domain; + } + + public Builder addTypeMapping(Map mappings) { + this.mappings.putAll(mappings); + return this; + } + + public Builder addTypeMapping(String message, Class mappingClass) { + this.mappings.put(message, mappingClass); + return this; + } + + public ProtoSchema forType(Class rawType) { + return new ProtoSchema( + rawType, + ProtobufUtil.getDescriptorForClass(rawType), + domain, + ImmutableMap.copyOf(mappings)); + } + + public ProtoSchema forDescriptor(Descriptors.Descriptor descriptor) { + return new ProtoSchema( + DynamicMessage.class, descriptor, domain, ImmutableMap.copyOf(mappings)); + } + } + + /** Overlay. 
*/ + public static class ProtoOverlayFactory implements Factory> { + + public ProtoOverlayFactory() {} + + @Override + public List create(Class clazz, Schema schema) { + return ProtoSchema.fromSchema(schema).getters; + } + } + + private class MessageToRowFunction implements SerializableFunction { + + private MessageToRowFunction() {} + + @Override + public Row apply(Message input) { + return Row.withSchema(schemaCoder.getSchema()) + .withFieldValueGettersHandleCollections(true) + .withFieldValueGetters(new ProtoOverlayFactory(), input) + .build(); + } + } + + private class RowToMessageFunction implements SerializableFunction { + + private RowToMessageFunction() {} + + @Override + public T apply(Row input) { + Message.Builder builder; + try { + if (DynamicMessage.class.equals(rawType)) { + builder = (Message.Builder) fnNewBuilder.invoke(rawType, descriptor); + } else { + builder = (Message.Builder) fnNewBuilder.invoke(rawType); + } + } catch (IllegalAccessException | InvocationTargetException e) { + throw new RuntimeException("Can't invoke newBuilder on the Protobuf message class.", e); + } + + Iterator values = input.getValues().iterator(); + Iterator overlayIterator = getters.iterator(); + + for (int i = 0; i < input.getValues().size(); i++) { + ProtoFieldOverlay getter = overlayIterator.next(); + Object value = values.next(); + getter.set(builder, value); + } + return (T) builder.build(); + } + } +} diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaProvider.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaProvider.java new file mode 100644 index 000000000000..45503a1c6628 --- /dev/null +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaProvider.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.protobuf; + +import com.google.protobuf.DynamicMessage; +import javax.annotation.Nullable; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaProvider; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Schema provider for Protobuf messages. The provider is able to handle pre compiled Message file + * without external help. For Dynamic Messages a Descriptor needs to be registered up front on a + * specific URN. + * + *

It's possible to inherit this class for a specific implementation that communicates with an + * external registry that maps those URN's with Descriptors. + */ +@Experimental(Experimental.Kind.SCHEMAS) +public class ProtoSchemaProvider implements SchemaProvider { + private static final Logger LOG = LoggerFactory.getLogger(ProtoSchemaProvider.class); + + private final ProtoSchema.Builder protoSchemaBuilder; + + public ProtoSchemaProvider() { + this.protoSchemaBuilder = ProtoSchema.newBuilder(); + } + + public ProtoSchemaProvider(ProtoSchema.Builder protoSchemaBuilder) { + this.protoSchemaBuilder = protoSchemaBuilder; + } + + @Override + public Schema schemaFor(TypeDescriptor typeDescriptor) { + checkForDynamicType(typeDescriptor); + return protoSchemaBuilder.forType(typeDescriptor.getRawType()).getSchema(); + } + + @Nullable + @Override + public SerializableFunction toRowFunction(TypeDescriptor typeDescriptor) { + checkForDynamicType(typeDescriptor); + return protoSchemaBuilder + .forType(typeDescriptor.getRawType()) + .getSchemaCoder() + .getToRowFunction(); + } + + @Override + public SerializableFunction fromRowFunction(TypeDescriptor typeDescriptor) { + checkForDynamicType(typeDescriptor); + return protoSchemaBuilder + .forType(typeDescriptor.getRawType()) + .getSchemaCoder() + .getFromRowFunction(); + } + + private void checkForDynamicType(TypeDescriptor typeDescriptor) { + if (typeDescriptor.getRawType().equals(DynamicMessage.class)) { + throw new RuntimeException( + "DynamicMessage is not allowed for the standard ProtoSchemaProvider, use ProtoSchema build instead."); + } + } +} diff --git a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTest.java b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTest.java new file mode 100644 index 000000000000..b49ac5d5d5d6 --- /dev/null +++ 
b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTest.java @@ -0,0 +1,577 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.protobuf; + +import static org.junit.Assert.assertEquals; + +import com.google.protobuf.Descriptors; +import com.google.protobuf.DynamicMessage; +import com.google.protobuf.Message; +import java.io.IOException; +import java.util.Objects; +import org.apache.beam.sdk.coders.RowCoder; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaCoder; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.util.CoderUtils; +import org.apache.beam.sdk.util.SerializableUtils; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Collection of standard tests for Protobuf Schema support. 
*/
+@RunWith(JUnit4.class)
+public class ProtoSchemaTest {
+
+  /** Expected schema for a message with one field of every scalar protobuf type. */
+  private static final Schema PRIMITIVE_SCHEMA =
+      Schema.builder()
+          .addDoubleField("primitive_double")
+          .addFloatField("primitive_float")
+          .addInt32Field("primitive_int32")
+          .addInt64Field("primitive_int64")
+          .addInt32Field("primitive_uint32")
+          .addInt64Field("primitive_uint64")
+          .addInt32Field("primitive_sint32")
+          .addInt64Field("primitive_sint64")
+          .addInt32Field("primitive_fixed32")
+          .addInt64Field("primitive_fixed64")
+          .addInt32Field("primitive_sfixed32")
+          .addInt64Field("primitive_sfixed64")
+          .addBooleanField("primitive_bool")
+          .addStringField("primitive_string")
+          .addByteArrayField("primitive_bytes")
+          .build();
+  /** Row expected for a default (empty) primitive message: zero, false, "" and empty bytes. */
+  static final Row PRIMITIVE_DEFAULT_ROW =
+      Row.withSchema(PRIMITIVE_SCHEMA)
+          .addValue((double) 0)
+          .addValue((float) 0)
+          .addValue(0)
+          .addValue(0L)
+          .addValue(0)
+          .addValue(0L)
+          .addValue(0)
+          .addValue(0L)
+          .addValue(0)
+          .addValue(0L)
+          .addValue(0)
+          .addValue(0L)
+          .addValue(Boolean.FALSE)
+          .addValue("")
+          .addValue(new byte[] {})
+          .build();
+  /** Expected schema for a message holding a nested message and a repeated nested message. */
+  static final Schema MESSAGE_SCHEMA =
+      Schema.builder()
+          .addField("message", Schema.FieldType.row(PRIMITIVE_SCHEMA).withNullable(true))
+          .addField(
+              "repeated_message",
+              Schema.FieldType.array(
+                      // TODO: confirm that both the array and its elements should be nullable.
+                      Schema.FieldType.row(PRIMITIVE_SCHEMA).withNullable(true))
+                  .withNullable(true))
+          .build();
+  /** Row expected for a default message: both fields are null. */
+  private static final Row MESSAGE_DEFAULT_ROW =
+      Row.withSchema(MESSAGE_SCHEMA).addValue(null).addValue(null).build();
+  /** Expected schema for repeated scalars: every field is a nullable array. */
+  private static final Schema REPEAT_PRIMITIVE_SCHEMA =
+      Schema.builder()
+          .addField(
+              "repeated_double", Schema.FieldType.array(Schema.FieldType.DOUBLE).withNullable(true))
+          .addField(
+              "repeated_float", Schema.FieldType.array(Schema.FieldType.FLOAT).withNullable(true))
+          .addField(
+              "repeated_int32", Schema.FieldType.array(Schema.FieldType.INT32).withNullable(true))
+          .addField(
+              "repeated_int64", Schema.FieldType.array(Schema.FieldType.INT64).withNullable(true))
+          .addField(
+              "repeated_uint32", Schema.FieldType.array(Schema.FieldType.INT32).withNullable(true))
+          .addField(
+              "repeated_uint64", Schema.FieldType.array(Schema.FieldType.INT64).withNullable(true))
+          .addField(
+              "repeated_sint32", Schema.FieldType.array(Schema.FieldType.INT32).withNullable(true))
+          .addField(
+              "repeated_sint64", Schema.FieldType.array(Schema.FieldType.INT64).withNullable(true))
+          .addField(
+              "repeated_fixed32", Schema.FieldType.array(Schema.FieldType.INT32).withNullable(true))
+          .addField(
+              "repeated_fixed64", Schema.FieldType.array(Schema.FieldType.INT64).withNullable(true))
+          .addField(
+              "repeated_sfixed32",
+              Schema.FieldType.array(Schema.FieldType.INT32).withNullable(true))
+          .addField(
+              "repeated_sfixed64",
+              Schema.FieldType.array(Schema.FieldType.INT64).withNullable(true))
+          .addField(
+              "repeated_bool", Schema.FieldType.array(Schema.FieldType.BOOLEAN).withNullable(true))
+          .addField(
+              "repeated_string", Schema.FieldType.array(Schema.FieldType.STRING).withNullable(true))
+          .addField(
+              "repeated_bytes", Schema.FieldType.array(Schema.FieldType.BYTES).withNullable(true))
+          .build();
+  /** Row expected for a default repeated-scalar message: every array is null. */
+  static final Row REPEAT_PRIMITIVE_DEFAULT_ROW =
+      Row.withSchema(REPEAT_PRIMITIVE_SCHEMA)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .build();
+  /** Expected schema for a message mixing enums, oneof members and maps. */
+  private static final Schema COMPLEX_SCHEMA =
+      Schema.builder()
+          .addField("special_enum", Schema.FieldType.STRING)
+          .addField(
+              "repeated_enum", Schema.FieldType.array(Schema.FieldType.STRING).withNullable(true))
+          .addField("oneof_int32", Schema.FieldType.INT32.withNullable(true))
+          .addField("oneof_bool", Schema.FieldType.BOOLEAN.withNullable(true))
+          .addField("oneof_string", Schema.FieldType.STRING.withNullable(true))
+          .addField("oneof_primitive", Schema.FieldType.row(PRIMITIVE_SCHEMA).withNullable(true))
+          .addField(
+              "x",
+              Schema.FieldType.map(Schema.FieldType.STRING, Schema.FieldType.INT32)
+                  .withNullable(true))
+          .addField(
+              "y",
+              Schema.FieldType.map(Schema.FieldType.STRING, Schema.FieldType.STRING)
+                  .withNullable(true))
+          .addField(
+              "z",
+              // TODO: decide whether a nullable value type inside a map makes sense.
+              Schema.FieldType.map(
+                      Schema.FieldType.STRING,
+                      Schema.FieldType.row(PRIMITIVE_SCHEMA).withNullable(true))
+                  .withNullable(true))
+          .addField("oneof_int64", Schema.FieldType.INT64.withNullable(true))
+          .addField("oneof_double", Schema.FieldType.DOUBLE.withNullable(true))
+          .build();
+  /** Row expected for a default complex message: the enum name "UNKNOWN", everything else null. */
+  static final Row COMPLEX_DEFAULT_ROW =
+      Row.withSchema(COMPLEX_SCHEMA)
+          .addValue("UNKNOWN")
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .build();
+  /** Expected schema for well-known types: wrappers become nullable scalars, Timestamp DATETIME. */
+  private static final Schema WKT_MESSAGE_SCHEMA =
+      Schema.builder()
+          .addField("nullable_double", Schema.FieldType.DOUBLE.withNullable(true))
+          .addField("nullable_float", Schema.FieldType.FLOAT.withNullable(true))
+          .addField("nullable_int32", Schema.FieldType.INT32.withNullable(true))
+          .addField("nullable_int64", Schema.FieldType.INT64.withNullable(true))
+          .addField("nullable_uint32", Schema.FieldType.INT32.withNullable(true))
+          .addField("nullable_uint64", Schema.FieldType.INT64.withNullable(true))
+          // Remaining wrapper types.
+          .addField("nullable_bool", Schema.FieldType.BOOLEAN.withNullable(true))
+          .addField("nullable_string", Schema.FieldType.STRING.withNullable(true))
+          .addField("nullable_bytes", Schema.FieldType.BYTES.withNullable(true))
+          // Non-wrapper well-known types.
+          .addField("wkt_timestamp", Schema.FieldType.DATETIME.withNullable(true))
+          .build();
+  /** Row expected for a default well-known-type message: every field is null. */
+  static final Row WKT_MESSAGE_DEFAULT_ROW =
+      Row.withSchema(WKT_MESSAGE_SCHEMA)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .addValue(null)
+          .build();
+
+  @Test
+  public void testPrimitiveSchema() {
+    Schema schema =
+        new ProtoSchemaProvider()
+            .schemaFor(TypeDescriptor.of(Proto3SchemaMessages.Primitive.class));
+    assertEquals(PRIMITIVE_SCHEMA, schema);
+  }
+
+  @Test
+  public void testPrimitiveDefaultRow() {
+    SerializableFunction toRowFunction =
+        new ProtoSchemaProvider()
+            .toRowFunction(TypeDescriptor.of(Proto3SchemaMessages.Primitive.class));
+    Row row = toRowFunction.apply(Proto3SchemaMessages.Primitive.newBuilder().build());
+    assertEquals(PRIMITIVE_DEFAULT_ROW, row);
+  }
+
+  @Test
+  public void testMessageSchema() {
+    Schema schema =
+        new ProtoSchemaProvider().schemaFor(TypeDescriptor.of(Proto3SchemaMessages.Message.class));
+    assertEquals(MESSAGE_SCHEMA, schema);
+  }
+
+  @Test
+  public void testMessageDefaultRow() {
+    SerializableFunction toRowFunction =
+        new ProtoSchemaProvider()
+            .toRowFunction(TypeDescriptor.of(Proto3SchemaMessages.Message.class));
+    Row row = toRowFunction.apply(Proto3SchemaMessages.Message.newBuilder().build());
+    assertEquals(MESSAGE_DEFAULT_ROW, row);
+  }
+
+  @Test
+  public void testRepeatPrimitiveSchema() {
+    Schema schema =
+        new ProtoSchemaProvider()
+            .schemaFor(TypeDescriptor.of(Proto3SchemaMessages.RepeatPrimitive.class));
+    assertEquals(REPEAT_PRIMITIVE_SCHEMA, schema);
+  }
+
+  @Test
+  public void testRepeatPrimitiveDefaultRow() {
+    SerializableFunction toRowFunction =
+        new ProtoSchemaProvider()
+            .toRowFunction(TypeDescriptor.of(Proto3SchemaMessages.RepeatPrimitive.class));
+    Row row = toRowFunction.apply(Proto3SchemaMessages.RepeatPrimitive.newBuilder().build());
+    assertEquals(REPEAT_PRIMITIVE_DEFAULT_ROW, row);
+  }
+
+  @Test
+  public void testComplexSchema() {
+    Schema schema =
+        new ProtoSchemaProvider().schemaFor(TypeDescriptor.of(Proto3SchemaMessages.Complex.class));
+    assertEquals(COMPLEX_SCHEMA, schema);
+  }
+
+  @Test
+  public void testComplexDefaultRow() {
+    SerializableFunction toRowFunction =
+        new ProtoSchemaProvider()
+            .toRowFunction(TypeDescriptor.of(Proto3SchemaMessages.Complex.class));
+    Row row = toRowFunction.apply(Proto3SchemaMessages.Complex.newBuilder().build());
+    assertEquals(COMPLEX_DEFAULT_ROW, row);
+  }
+
+  @Test
+  public void testWktMessageSchema() {
+    Schema schema =
+        new ProtoSchemaProvider()
+            .schemaFor(TypeDescriptor.of(Proto3SchemaMessages.WktMessage.class));
+    assertEquals(WKT_MESSAGE_SCHEMA, schema);
+  }
+
+  @Test
+  public void testWktMessageDefaultRow() {
+    SerializableFunction toRowFunction =
+        new ProtoSchemaProvider()
+            .toRowFunction(TypeDescriptor.of(Proto3SchemaMessages.WktMessage.class));
+    Row row = toRowFunction.apply(Proto3SchemaMessages.WktMessage.newBuilder().build());
+    assertEquals(WKT_MESSAGE_DEFAULT_ROW, row);
+  }
+
+  /**
+   * Round-trips a generated message through its row form and the row coder, and checks that both
+   * the original and the Java-serialized schema coder rebuild an equal message.
+   */
+  @Test
+  public void testCoder() throws Exception {
+    SchemaCoder schemaCoder =
+        ProtoSchema.newBuilder().forType(Proto3SchemaMessages.Complex.class).getSchemaCoder();
+    RowCoder rowCoder = RowCoder.of(schemaCoder.getSchema());
+
+    // Push both coders through Java serialization, as would happen when shipping them to workers.
+    byte[] schemaCoderBytes = SerializableUtils.serializeToByteArray(schemaCoder);
+    SchemaCoder schemaCoderCoded =
+        (SchemaCoder) SerializableUtils.deserializeFromByteArray(schemaCoderBytes, "");
+    byte[] rowCoderBytes = SerializableUtils.serializeToByteArray(rowCoder);
+    RowCoder rowCoderCoded =
+        (RowCoder) SerializableUtils.deserializeFromByteArray(rowCoderBytes, "");
+
+    Proto3SchemaMessages.Complex message =
+        Proto3SchemaMessages.Complex.newBuilder()
+            .setOneofString("foobar")
+            .setSpecialEnum(Proto3SchemaMessages.Complex.EnumNested.FOO)
+            .build();
+
+    Row row = schemaCoder.getToRowFunction().apply(message);
+    byte[] rowBytes = CoderUtils.encodeToByteArray(rowCoder, row);
+
+    Row rowCoded = CoderUtils.decodeFromByteArray(rowCoderCoded, rowBytes);
+    assertEquals(row, rowCoded);
+
+    Message messageVerify = schemaCoder.getFromRowFunction().apply(rowCoded);
+    assertEquals(message, messageVerify);
+
+    Message messageCoded = schemaCoderCoded.getFromRowFunction().apply(rowCoded);
+    assertEquals(message, messageCoded);
+  }
+
+  /**
+   * Same round trip as {@link #testCoder()} for a {@link DynamicMessage}. The message rebuilt by
+   * the deserialized coder is compared field by field, through descriptors looked up on that
+   * message, rather than with a whole-message equality check.
+   */
+  @Test
+  public void testCoderOnDynamic() throws Exception {
+    Descriptors.Descriptor descriptor = Proto3SchemaMessages.Complex.getDescriptor();
+    Descriptors.FieldDescriptor oneofString = descriptor.findFieldByName("oneof_string");
+    Descriptors.FieldDescriptor specialEnum = descriptor.findFieldByName("special_enum");
+
+    SchemaCoder schemaCoder =
+        ProtoSchema.newBuilder(ProtoDomain.buildFrom(descriptor))
+            .forDescriptor(descriptor)
+            .getSchemaCoder();
+    RowCoder rowCoder = RowCoder.of(schemaCoder.getSchema());
+
+    byte[] schemaCoderBytes = SerializableUtils.serializeToByteArray(schemaCoder);
+    SchemaCoder schemaCoderCoded =
+        (SchemaCoder) SerializableUtils.deserializeFromByteArray(schemaCoderBytes, "");
+    byte[] rowCoderBytes = SerializableUtils.serializeToByteArray(rowCoder);
+    RowCoder rowCoderCoded =
+        (RowCoder) SerializableUtils.deserializeFromByteArray(rowCoderBytes, "");
+
+    DynamicMessage message =
+        DynamicMessage.newBuilder(descriptor)
+            .setField(oneofString, "foobar")
+            .setField(specialEnum, Proto3SchemaMessages.Complex.EnumNested.FOO.getValueDescriptor())
+            .build();
+
+    Row row = schemaCoder.getToRowFunction().apply(message);
+    byte[] rowBytes = CoderUtils.encodeToByteArray(rowCoder, row);
+
+    Row rowCoded = CoderUtils.decodeFromByteArray(rowCoderCoded, rowBytes);
+    assertEquals(row, rowCoded);
+
+    Message messageVerify = schemaCoder.getFromRowFunction().apply(rowCoded);
+    assertEquals(message, messageVerify);
+
+    Message messageCoded = schemaCoderCoded.getFromRowFunction().apply(rowCoded);
+    Descriptors.FieldDescriptor oneofStringCoded =
+        messageCoded.getDescriptorForType().findFieldByName("oneof_string");
+    Descriptors.FieldDescriptor specialEnumCoded =
+        messageCoded.getDescriptorForType().findFieldByName("special_enum");
+    assertEquals(message.getField(oneofString), messageCoded.getField(oneofStringCoded));
+    assertEquals(
+        ((Descriptors.EnumValueDescriptor) message.getField(specialEnum)).getFullName(),
+        ((Descriptors.EnumValueDescriptor) messageCoded.getField(specialEnumCoded)).getFullName());
+  }
+
+  /** Registers a custom overlay for the LatLng message and checks the row exposes the Java type. */
+  @Test
+  public void testLogicalTypeRegistration() {
+    ProtoSchemaProvider protoSchemaProvider =
+        new ProtoSchemaProvider(
+            ProtoSchema.newBuilder()
+                .addTypeMapping("proto3_schema_messages.LatLng", LatLngOverlay.class));
+
+    SerializableFunction toRowFunction =
+        protoSchemaProvider.toRowFunction(
+            TypeDescriptor.of(Proto3SchemaMessages.LogicalTypes.class));
+    Row row =
+        toRowFunction.apply(
+            Proto3SchemaMessages.LogicalTypes.newBuilder()
+                .setGps(
+                    Proto3SchemaMessages.LatLng.newBuilder()
+                        .setLatitude(1.2)
+                        .setLongitude(3.5)
+                        .build())
+                .build());
+    assertEquals(new LatLng(1.2, 3.5), row.getValue("gps"));
+  }
+
+  /** Example Java type for testing LogicalTypes. */
+  public static class LatLng {
+    double lat;
+    double lng;
+
+    public LatLng(Double lat, Double lng) {
+      this.lat = lat;
+      this.lng = lng;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+      LatLng latLng = (LatLng) o;
+      return Double.compare(latLng.lat, lat) == 0 && Double.compare(latLng.lng, lng) == 0;
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(lat, lng);
+    }
+  }
+
+  /**
+   * Example of a LogicalType converter. It maps {@link LatLng} onto a row with two double fields,
+   * {@code lat} and {@code lng}, so the type is convertible to a FieldType.
+   */
+  public static class LatLngLogicalType implements Schema.LogicalType {
+
+    static final Schema BASE_TYPE =
+        Schema.builder()
+            .addField("lat", Schema.FieldType.DOUBLE)
+            .addField("lng", Schema.FieldType.DOUBLE)
+            .build();
+
+    @Override
+    public String getIdentifier() {
+      return "LatLngLogicalType";
+    }
+
+    @Override
+    public Schema.FieldType getBaseType() {
+      return Schema.FieldType.row(LatLngLogicalType.BASE_TYPE);
+    }
+
+    @Override
+    public Row toBaseType(LatLng input) {
+      return Row.withSchema(BASE_TYPE).addValue(input.lat).addValue(input.lng).build();
+    }
+
+    @Override
+    public LatLng toInputType(Row base) {
+      return new LatLng(base.getDouble("lat"), base.getDouble("lng"));
+    }
+  }
+
+  /** Custom Protobuf field overlay that returns a custom LogicalType. */
+  public static class LatLngOverlay implements ProtoFieldOverlay {
+    private Descriptors.FieldDescriptor fieldDescriptor;
+    private Descriptors.FieldDescriptor latitudeFieldDescriptor;
+    private Descriptors.FieldDescriptor longitudeFieldDescriptor;
+
+    /** Constructor signature looked up reflectively by ProtoSchema for registered mappings. */
+    public LatLngOverlay(Descriptors.FieldDescriptor fieldDescriptor) {
+      this.fieldDescriptor = fieldDescriptor;
+      latitudeFieldDescriptor = fieldDescriptor.getMessageType().findFieldByName("latitude");
+      longitudeFieldDescriptor = fieldDescriptor.getMessageType().findFieldByName("longitude");
+    }
+
+    @Override
+    public LatLng get(Message message) {
+      Message latLngMessage = (Message) message.getField(fieldDescriptor);
+      return new LatLng(
+          (double) latLngMessage.getField(latitudeFieldDescriptor),
+          (double) latLngMessage.getField(longitudeFieldDescriptor));
+    }
+
+    @Override
+    public String name() {
+      return fieldDescriptor.getName();
+    }
+
+    // Stub: always returns null.
+    @Override
+    public LatLng convertGetObject(Descriptors.FieldDescriptor fieldDescriptor, Object object) {
+      return null;
+    }
+
+    @Override
+    public void set(Message.Builder message, LatLng value) {
+      message.setField(
+          fieldDescriptor,
+          Proto3SchemaMessages.LatLng.newBuilder()
+              .setLongitude(value.lng)
+              .setLatitude(value.lat)
+              .build());
+    }
+
+    // Stub: always returns null.
+    @Override
+    public Object convertSetObject(Descriptors.FieldDescriptor fieldDescriptor, Object value) {
+      return null;
+    }
+
+    @Override
+    public Schema.Field getSchemaField() {
+      return Schema.Field.of(
+          fieldDescriptor.getName(), Schema.FieldType.logicalType(new LatLngLogicalType()));
+    }
+  }
+
+  /**
+   * Checks that field options of a generated message surface as field-type metadata: the custom
+   * {@code field_meta} option is flattened per field and the built-in {@code deprecated} option is
+   * keyed by its short name.
+   */
+  @Test
+  public void testMessageWithMetaSchema() {
+    Schema schema =
+        new ProtoSchemaProvider()
+            .schemaFor(TypeDescriptor.of(Proto3SchemaMessages.MessageWithMeta.class));
+    Schema.Field fieldWithDescription = schema.getField("field_with_description");
+    assertEquals(
+        "Cool field",
+        fieldWithDescription
+            .getType()
+            .getMetadataString("proto3_schema_messages.field_meta.description"));
+    assertEquals(
+        "0",
+        fieldWithDescription
+            .getType()
+            .getMetadataString("proto3_schema_messages.field_meta.foobar"));
+    assertEquals("", fieldWithDescription.getType().getMetadataString("deprecated"));
+
+    Schema.Field fieldWithFoobar = schema.getField("field_with_foobar");
+    assertEquals(
+        "",
+        fieldWithFoobar
+            .getType()
+            .getMetadataString("proto3_schema_messages.field_meta.description"));
+    assertEquals(
+        "42",
+        fieldWithFoobar.getType().getMetadataString("proto3_schema_messages.field_meta.foobar"));
+    assertEquals("", fieldWithFoobar.getType().getMetadataString("deprecated"));
+
+    Schema.Field fieldWithDeprecation = schema.getField("field_with_deprecation");
+    assertEquals(
+        "",
+        fieldWithDeprecation
+            .getType()
+            .getMetadataString("proto3_schema_messages.field_meta.description"));
+    assertEquals(
+        "",
+        fieldWithDeprecation
+            .getType()
+            .getMetadataString("proto3_schema_messages.field_meta.foobar"));
+    assertEquals("true", fieldWithDeprecation.getType().getMetadataString("deprecated"));
+  }
+
+  /**
+   * Checks option metadata for a schema built from a descriptor set loaded from the {@code
+   * test_option_v1.pb} resource; repeated option values are joined with a newline.
+   */
+  @Test
+  public void testMessageWithMetaDynamicSchema() throws IOException {
+    ProtoDomain domain = ProtoDomain.buildFrom(getClass().getResourceAsStream("test_option_v1.pb"));
+    Descriptors.Descriptor descriptor = domain.getDescriptor("test.option.v1.MessageWithOptions");
+    Schema schema = ProtoSchema.newBuilder(domain).forDescriptor(descriptor).getSchema();
+    Schema.Field field;
+    field = schema.getField("field_with_fieldoption_double");
+    assertEquals("100.1", field.getType().getMetadataString("test.option.v1.fieldoption_double"));
+    field = schema.getField("field_with_fieldoption_float");
+    assertEquals("101.2", field.getType().getMetadataString("test.option.v1.fieldoption_float"));
+    field = schema.getField("field_with_fieldoption_int32");
+    assertEquals("102", field.getType().getMetadataString("test.option.v1.fieldoption_int32"));
+    field = schema.getField("field_with_fieldoption_int64");
+    assertEquals("103", field.getType().getMetadataString("test.option.v1.fieldoption_int64"));
+    field = schema.getField("field_with_fieldoption_bool");
+    assertEquals("true", field.getType().getMetadataString("test.option.v1.fieldoption_bool"));
+    field = schema.getField("field_with_fieldoption_string");
+    assertEquals("Oh yeah", field.getType().getMetadataString("test.option.v1.fieldoption_string"));
+    field = schema.getField("field_with_fieldoption_enum");
+    assertEquals("ENUM1", field.getType().getMetadataString("test.option.v1.fieldoption_enum"));
+    field = schema.getField("field_with_fieldoption_repeated_string");
+    assertEquals(
+        "Oh yeah\nOh no",
+        field.getType().getMetadataString("test.option.v1.fieldoption_repeated_string"));
+  }
+}
diff --git a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaValuesTest.java b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaValuesTest.java
new file mode 100644
index 000000000000..7d9e8aa6ac7a
--- /dev/null
+++ b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaValuesTest.java
@@ -0,0 +1,670 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.protobuf; + +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTest.COMPLEX_DEFAULT_ROW; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTest.MESSAGE_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTest.PRIMITIVE_DEFAULT_ROW; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTest.REPEAT_PRIMITIVE_DEFAULT_ROW; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTest.WKT_MESSAGE_DEFAULT_ROW; +import static org.junit.Assert.assertEquals; + +import com.google.protobuf.BoolValue; +import com.google.protobuf.ByteString; +import com.google.protobuf.BytesValue; +import com.google.protobuf.DoubleValue; +import com.google.protobuf.DynamicMessage; +import com.google.protobuf.FloatValue; +import com.google.protobuf.Int32Value; +import com.google.protobuf.Int64Value; +import com.google.protobuf.Message; +import com.google.protobuf.StringValue; +import com.google.protobuf.Timestamp; +import com.google.protobuf.UInt32Value; +import com.google.protobuf.UInt64Value; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import 
org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +/** Collection of tests for values on Protobuf Messages and Rows. */ +@RunWith(Parameterized.class) +public class ProtoSchemaValuesTest { + + private final Message proto; + private final Row rowObject; + private SerializableFunction toRowFunction; + private SerializableFunction fromRowFunction; + + public ProtoSchemaValuesTest(String description, Message proto, Row rowObject) { + this.proto = proto; + this.rowObject = rowObject; + } + + @Parameterized.Parameters(name = "{0}") + public static Collection data() { + List data = new ArrayList<>(); + data.add( + new Object[] { + "primitive_int32", + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveInt32(Integer.MAX_VALUE).build(), + change(PRIMITIVE_DEFAULT_ROW, "primitive_int32", Integer.MAX_VALUE) + }); + data.add( + new Object[] { + "primitive_int64", + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveInt64(Long.MAX_VALUE).build(), + change(PRIMITIVE_DEFAULT_ROW, "primitive_int64", Long.MAX_VALUE) + }); + data.add( + new Object[] { + "primitive_uint32", + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveUint32(Integer.MAX_VALUE).build(), + change(PRIMITIVE_DEFAULT_ROW, "primitive_uint32", Integer.MAX_VALUE) + }); + data.add( + new Object[] { + "primitive_uint64", + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveUint64(Long.MAX_VALUE).build(), + change(PRIMITIVE_DEFAULT_ROW, "primitive_uint64", Long.MAX_VALUE) + }); + data.add( + new Object[] { + "primitive_sint32", + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveSint32(Integer.MAX_VALUE).build(), + change(PRIMITIVE_DEFAULT_ROW, "primitive_sint32", Integer.MAX_VALUE) + }); + data.add( + new Object[] { + 
"primitive_sint64", + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveSint64(Long.MAX_VALUE).build(), + change(PRIMITIVE_DEFAULT_ROW, "primitive_sint64", Long.MAX_VALUE) + }); + data.add( + new Object[] { + "primitive_fixed32", + Proto3SchemaMessages.Primitive.newBuilder() + .setPrimitiveFixed32(Integer.MAX_VALUE) + .build(), + change(PRIMITIVE_DEFAULT_ROW, "primitive_fixed32", Integer.MAX_VALUE) + }); + data.add( + new Object[] { + "primitive_fixed64", + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveFixed64(Long.MAX_VALUE).build(), + change(PRIMITIVE_DEFAULT_ROW, "primitive_fixed64", Long.MAX_VALUE) + }); + data.add( + new Object[] { + "primitive_sfixed32", + Proto3SchemaMessages.Primitive.newBuilder() + .setPrimitiveSfixed32(Integer.MAX_VALUE) + .build(), + change(PRIMITIVE_DEFAULT_ROW, "primitive_sfixed32", Integer.MAX_VALUE) + }); + data.add( + new Object[] { + "primitive_sfixed64", + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveSfixed64(Long.MAX_VALUE).build(), + change(PRIMITIVE_DEFAULT_ROW, "primitive_sfixed64", Long.MAX_VALUE) + }); + data.add( + new Object[] { + "primitive_bool", + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveBool(true).build(), + change(PRIMITIVE_DEFAULT_ROW, "primitive_bool", true) + }); + data.add( + new Object[] { + "primitive_string", + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveString("lovely string").build(), + change(PRIMITIVE_DEFAULT_ROW, "primitive_string", "lovely string") + }); + data.add( + new Object[] { + "primitive_bytes", + Proto3SchemaMessages.Primitive.newBuilder() + .setPrimitiveBytes(ByteString.copyFrom(new byte[] {(byte) 0x0F})) + .build(), + change(PRIMITIVE_DEFAULT_ROW, "primitive_bytes", new byte[] {(byte) 0x0F}) + }); + + data.add( + new Object[] { + "repeated_double", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedDouble(Double.MAX_VALUE) + .addRepeatedDouble(0.0) + .addRepeatedDouble(Double.MIN_VALUE) + .build(), + change( + 
REPEAT_PRIMITIVE_DEFAULT_ROW, + "repeated_double", + Arrays.asList(Double.MAX_VALUE, 0.0, Double.MIN_VALUE)) + }); + data.add( + new Object[] { + "repeated_float", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedFloat(Float.MAX_VALUE) + .addRepeatedFloat((float) 0.0) + .addRepeatedFloat(Float.MIN_VALUE) + .build(), + change( + REPEAT_PRIMITIVE_DEFAULT_ROW, + "repeated_float", + Arrays.asList(Float.MAX_VALUE, (float) 0.0, Float.MIN_VALUE)) + }); + data.add( + new Object[] { + "repeated_int32", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedInt32(Integer.MAX_VALUE) + .addRepeatedInt32(0) + .addRepeatedInt32(Integer.MIN_VALUE) + .build(), + change( + REPEAT_PRIMITIVE_DEFAULT_ROW, + "repeated_int32", + Arrays.asList(Integer.MAX_VALUE, 0, Integer.MIN_VALUE)) + }); + data.add( + new Object[] { + "repeated_int64", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedInt64(Long.MAX_VALUE) + .addRepeatedInt64(0L) + .addRepeatedInt64(Long.MIN_VALUE) + .build(), + change( + REPEAT_PRIMITIVE_DEFAULT_ROW, + "repeated_int64", + Arrays.asList(Long.MAX_VALUE, 0L, Long.MIN_VALUE)) + }); + + data.add( + new Object[] { + "repeated_uint32", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedUint32(Integer.MAX_VALUE) + .addRepeatedUint32(0) + .build(), + change( + REPEAT_PRIMITIVE_DEFAULT_ROW, "repeated_uint32", Arrays.asList(Integer.MAX_VALUE, 0)) + }); + data.add( + new Object[] { + "repeated_uint64", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedUint64(Long.MAX_VALUE) + .addRepeatedUint64(0L) + .build(), + change(REPEAT_PRIMITIVE_DEFAULT_ROW, "repeated_uint64", Arrays.asList(Long.MAX_VALUE, 0L)) + }); + + data.add( + new Object[] { + "repeated_sint32", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedSint32(Integer.MAX_VALUE) + .addRepeatedSint32(0) + .build(), + change( + REPEAT_PRIMITIVE_DEFAULT_ROW, "repeated_sint32", Arrays.asList(Integer.MAX_VALUE, 0)) + }); + data.add( + new 
Object[] { + "repeated_sint64", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedSint64(Long.MAX_VALUE) + .addRepeatedSint64(0L) + .build(), + change(REPEAT_PRIMITIVE_DEFAULT_ROW, "repeated_sint64", Arrays.asList(Long.MAX_VALUE, 0L)) + }); + data.add( + new Object[] { + "repeated_fixed32", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedFixed32(Integer.MAX_VALUE) + .addRepeatedFixed32(0) + .build(), + change( + REPEAT_PRIMITIVE_DEFAULT_ROW, "repeated_fixed32", Arrays.asList(Integer.MAX_VALUE, 0)) + }); + data.add( + new Object[] { + "repeated_fixed64", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedFixed64(Long.MAX_VALUE) + .addRepeatedFixed64(0) + .build(), + change( + REPEAT_PRIMITIVE_DEFAULT_ROW, "repeated_fixed64", Arrays.asList(Long.MAX_VALUE, 0L)) + }); + data.add( + new Object[] { + "repeated_sfixed32", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedSfixed32(Integer.MAX_VALUE) + .addRepeatedSfixed32(0) + .build(), + change( + REPEAT_PRIMITIVE_DEFAULT_ROW, + "repeated_sfixed32", + Arrays.asList(Integer.MAX_VALUE, 0)) + }); + data.add( + new Object[] { + "repeated_sfixed64", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedSfixed64(Long.MAX_VALUE) + .addRepeatedSfixed64(0L) + .build(), + change( + REPEAT_PRIMITIVE_DEFAULT_ROW, "repeated_sfixed64", Arrays.asList(Long.MAX_VALUE, 0L)) + }); + data.add( + new Object[] { + "repeated_bool", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedBool(true) + .addRepeatedBool(false) + .build(), + change(REPEAT_PRIMITIVE_DEFAULT_ROW, "repeated_bool", Arrays.asList(true, false)) + }); + data.add( + new Object[] { + "repeated_string", + Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedString("foo") + .addRepeatedString("bar") + .build(), + change(REPEAT_PRIMITIVE_DEFAULT_ROW, "repeated_string", Arrays.asList("foo", "bar")) + }); + data.add( + new Object[] { + "repeated_bytes", + 
Proto3SchemaMessages.RepeatPrimitive.newBuilder() + .addRepeatedBytes(ByteString.copyFrom(new byte[] {(byte) 0x0F})) + .build(), + change( + REPEAT_PRIMITIVE_DEFAULT_ROW, + "repeated_bytes", + Arrays.asList(new byte[][] {new byte[] {(byte) 0x0F}})) + }); + + data.add( + new Object[] { + "nullable_double_nvs", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableDouble(DoubleValue.newBuilder().build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_double", 0.0) + }); + data.add( + new Object[] { + "nullable_double", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableDouble(DoubleValue.newBuilder().setValue(Double.MAX_VALUE).build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_double", Double.MAX_VALUE) + }); + data.add( + new Object[] { + "nullable_float_nvs", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableFloat(FloatValue.newBuilder().build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_float", (float) 0) + }); + data.add( + new Object[] { + "nullable_float", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableFloat(FloatValue.newBuilder().setValue(Float.MAX_VALUE).build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_float", Float.MAX_VALUE) + }); + data.add( + new Object[] { + "nullable_int32", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableInt32(Int32Value.newBuilder().setValue(Integer.MAX_VALUE).build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_int32", Integer.MAX_VALUE) + }); + data.add( + new Object[] { + "nullable_int32_nvs", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableInt32(Int32Value.newBuilder().build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_int32", 0) + }); + data.add( + new Object[] { + "nullable_int64", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableInt64(Int64Value.newBuilder().setValue(Long.MAX_VALUE).build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_int64", Long.MAX_VALUE) 
+ }); + data.add( + new Object[] { + "nullable_int64_nvs", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableInt64(Int64Value.newBuilder().build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_int64", 0L) + }); + data.add( + new Object[] { + "nullable_uint32", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableUint32(UInt32Value.newBuilder().setValue(Integer.MAX_VALUE).build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_uint32", Integer.MAX_VALUE) + }); + data.add( + new Object[] { + "nullable_uint32_nvs", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableUint32(UInt32Value.newBuilder().build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_uint32", 0) + }); + data.add( + new Object[] { + "nullable_uint64", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableUint64(UInt64Value.newBuilder().setValue(Long.MAX_VALUE).build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_uint64", Long.MAX_VALUE) + }); + data.add( + new Object[] { + "nullable_uint64_nvs", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableUint64(UInt64Value.newBuilder().build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_uint64", 0L) + }); + data.add( + new Object[] { + "nullable_bool", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableBool(BoolValue.newBuilder().setValue(Boolean.TRUE).build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_bool", Boolean.TRUE) + }); + data.add( + new Object[] { + "nullable_bool_nvs", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableBool(BoolValue.newBuilder().build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_bool", Boolean.FALSE) + }); + data.add( + new Object[] { + "nullable_string", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableString(StringValue.newBuilder().setValue("bar").build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_string", "bar") + }); + data.add( + new Object[] { + 
"nullable_string_nvs", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableString(StringValue.newBuilder().build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_string", "") + }); + data.add( + new Object[] { + "nullable_bytes", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableBytes( + BytesValue.newBuilder() + .setValue(ByteString.copyFrom(new byte[] {(byte) 0x0F})) + .build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_bytes", new byte[] {(byte) 0x0F}) + }); + data.add( + new Object[] { + "nullable_bytes_nvs", + Proto3SchemaMessages.WktMessage.newBuilder() + .setNullableBytes(BytesValue.newBuilder().build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "nullable_bytes", new byte[] {}) + }); + data.add( + new Object[] { + "wkt_timestamp", + Proto3SchemaMessages.WktMessage.newBuilder() + .setWktTimestamp( + Timestamp.newBuilder().setSeconds(1558680742).setNanos(123000000).build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "wkt_timestamp", new Instant(1558680742123L)) + }); + data.add( + new Object[] { + "wkt_timestamp_nvs", + Proto3SchemaMessages.WktMessage.newBuilder() + .setWktTimestamp(Timestamp.newBuilder().build()) + .build(), + change(WKT_MESSAGE_DEFAULT_ROW, "wkt_timestamp", new Instant(0)) + }); + + data.add( + new Object[] { + "special_enum", + Proto3SchemaMessages.Complex.newBuilder() + .setSpecialEnum(Proto3SchemaMessages.Complex.EnumNested.FOO) + .build(), + change(COMPLEX_DEFAULT_ROW, "special_enum", "FOO") + }); + data.add( + new Object[] { + "repeated_enum", + Proto3SchemaMessages.Complex.newBuilder() + .addRepeatedEnum(Proto3SchemaMessages.Complex.EnumNested.FOO) + .addRepeatedEnum(Proto3SchemaMessages.Complex.EnumNested.BAR) + .build(), + change(COMPLEX_DEFAULT_ROW, "repeated_enum", Arrays.asList("FOO", "BAR")) + }); + data.add( + new Object[] { + "oneof_int32", + Proto3SchemaMessages.Complex.newBuilder().setOneofInt32(42).build(), + change(COMPLEX_DEFAULT_ROW, "oneof_int32", 42) + }); + data.add( + 
new Object[] { + "oneof_bool", + Proto3SchemaMessages.Complex.newBuilder().setOneofBool(true).build(), + change(COMPLEX_DEFAULT_ROW, "oneof_bool", Boolean.TRUE) + }); + data.add( + new Object[] { + "oneof_string", + Proto3SchemaMessages.Complex.newBuilder().setOneofString("one_of_string").build(), + change(COMPLEX_DEFAULT_ROW, "oneof_string", "one_of_string") + }); + data.add( + new Object[] { + "oneof_primitive", + Proto3SchemaMessages.Complex.newBuilder() + .setOneofPrimitive( + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveInt32(42).build()) + .build(), + change( + COMPLEX_DEFAULT_ROW, + "oneof_primitive", + change(PRIMITIVE_DEFAULT_ROW, "primitive_int32", 42)) + }); + Map mapInt = new HashMap<>(); + mapInt.put("one", 1); + mapInt.put("two", 2); + data.add( + new Object[] { + "map_int", + Proto3SchemaMessages.Complex.newBuilder().putX("one", 1).putX("two", 2).build(), + change(COMPLEX_DEFAULT_ROW, "x", mapInt) + }); + Map mapString = new HashMap<>(); + mapString.put("one", "eno"); + mapString.put("two", "owt"); + data.add( + new Object[] { + "map_int", + Proto3SchemaMessages.Complex.newBuilder().putY("one", "eno").putY("two", "owt").build(), + change(COMPLEX_DEFAULT_ROW, "y", mapString) + }); + Map mapRow = new HashMap<>(); + mapRow.put("one", change(PRIMITIVE_DEFAULT_ROW, "primitive_int32", 1)); + mapRow.put("two", change(PRIMITIVE_DEFAULT_ROW, "primitive_string", "two")); + data.add( + new Object[] { + "map_row", + Proto3SchemaMessages.Complex.newBuilder() + .putZ("one", Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveInt32(1).build()) + .putZ( + "two", + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveString("two").build()) + .build(), + change(COMPLEX_DEFAULT_ROW, "z", mapRow) + }); + data.add( + new Object[] { + "subRow", + Proto3SchemaMessages.Message.newBuilder() + .setMessage( + Proto3SchemaMessages.Primitive.newBuilder() + .setPrimitiveString("we love strings") + .build()) + .build(), + Row.withSchema(MESSAGE_SCHEMA) + 
.addValue(change(PRIMITIVE_DEFAULT_ROW, "primitive_string", "we love strings")) + .addValue(null) + .build() + }); + data.add( + new Object[] { + "subRow+subArrayOfRow", + Proto3SchemaMessages.Message.newBuilder() + .setMessage(Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveInt32(42).build()) + .addRepeatedMessage( + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveInt32(69).build()) + .addRepeatedMessage( + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveInt32(70).build()) + .addRepeatedMessage( + Proto3SchemaMessages.Primitive.newBuilder().setPrimitiveInt32(71).build()) + .build(), + Row.withSchema(MESSAGE_SCHEMA) + .addValue(change(PRIMITIVE_DEFAULT_ROW, "primitive_int32", 42)) + .addArray( + change(PRIMITIVE_DEFAULT_ROW, "primitive_int32", 69), + change(PRIMITIVE_DEFAULT_ROW, "primitive_int32", 70), + change(PRIMITIVE_DEFAULT_ROW, "primitive_int32", 71)) + .build() + }); + return data; + } + + private static Row change(Row row, Object field, Object value) { + int index = -1; + List fields = row.getSchema().getFields(); + for (int i = 0; i < fields.size(); i++) { + Schema.Field f = fields.get(i); + if (f.getName().equals(field)) { + index = i; + break; + } + } + + Object[] objects = row.getValues().toArray(); + objects[index] = value; + return Row.withSchema(row.getSchema()).addValues(objects).build(); + } + + private void setup() { + ProtoSchemaProvider protoSchemaProvider = new ProtoSchemaProvider(); + TypeDescriptor typeDescriptor = TypeDescriptor.of(this.proto.getClass()); + + toRowFunction = protoSchemaProvider.toRowFunction(typeDescriptor); + fromRowFunction = protoSchemaProvider.fromRowFunction(typeDescriptor); + } + + private void setupForDynamicMessage() { + ProtoDomain domain = ProtoDomain.buildFrom(proto.getDescriptorForType()); + ProtoSchema protoSchema = + ProtoSchema.newBuilder(domain).forDescriptor(proto.getDescriptorForType()); + + toRowFunction = protoSchema.getSchemaCoder().getToRowFunction(); + fromRowFunction = 
protoSchema.getSchemaCoder().getFromRowFunction(); + } + + @Test + public void testRowAndBack() { + setup(); + Row row = toRowFunction.apply(this.proto); + Message message = fromRowFunction.apply(row); + assertEquals(proto, message); + } + + @Test + public void testToRow() { + setup(); + Row row = toRowFunction.apply(this.proto); + assertEquals(rowObject, row); + } + + @Test + public void testFromRow() { + setup(); + Message message = fromRowFunction.apply(this.rowObject); + assertEquals(this.proto, message); + } + + @Test + public void testToRowFromDynamicMessage() { + setupForDynamicMessage(); + Row row = toRowFunction.apply(DynamicMessage.newBuilder(this.proto).build()); + assertEquals(rowObject, row); + } + + @Test + public void testFromRowToDynamicMessage() { + setupForDynamicMessage(); + Message message = fromRowFunction.apply(this.rowObject); + assertEquals(this.proto, message); + } +} diff --git a/sdks/java/extensions/protobuf/src/test/proto/proto3_schema_messages.proto b/sdks/java/extensions/protobuf/src/test/proto/proto3_schema_messages.proto new file mode 100644 index 000000000000..68eaa1b2c7bc --- /dev/null +++ b/sdks/java/extensions/protobuf/src/test/proto/proto3_schema_messages.proto @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Protocol Buffer messages used for testing Proto3 Schema implementation. + */ + +syntax = "proto3"; + +package proto3_schema_messages; + +import "google/protobuf/timestamp.proto"; +import "google/protobuf/wrappers.proto"; +import "google/protobuf/descriptor.proto"; + +option java_package = "org.apache.beam.sdk.extensions.protobuf"; + +message MessageMeta { + string description = 1; + bool rewrite = 2; +} + +message FieldMeta { + string description = 1; + int32 foobar = 2; +} + +extend google.protobuf.MessageOptions { + MessageMeta message_meta = 66600666; +} + +extend google.protobuf.FieldOptions { + FieldMeta field_meta = 66600666; +} + +message Message { + Primitive message = 3; + repeated Primitive repeated_message = 4; +} + +message Primitive { + double primitive_double = 3; + float primitive_float = 4; + int32 primitive_int32 = 5; + int64 primitive_int64 = 6; + uint32 primitive_uint32 = 7; + uint64 primitive_uint64 = 8; + sint32 primitive_sint32 = 9; + sint64 primitive_sint64 = 10; + fixed32 primitive_fixed32 = 11; + fixed64 primitive_fixed64 = 12; + sfixed32 primitive_sfixed32 = 13; + sfixed64 primitive_sfixed64 = 14; + bool primitive_bool = 15; + string primitive_string = 16; + bytes primitive_bytes = 17; +} + +message RepeatPrimitive { + repeated double repeated_double = 1; + repeated float repeated_float = 2; + repeated int32 repeated_int32 = 3; + repeated int64 repeated_int64 = 4; + repeated uint32 repeated_uint32 = 5; + repeated uint64 repeated_uint64 = 6; + repeated sint32 repeated_sint32 = 7; + repeated sint64 repeated_sint64 = 8; + repeated fixed32 repeated_fixed32 = 9; + repeated fixed64 repeated_fixed64 = 10; + repeated sfixed32 repeated_sfixed32 = 11; + repeated sfixed64 repeated_sfixed64 = 12; + repeated bool repeated_bool = 13; + repeated string repeated_string = 14; + repeated bytes repeated_bytes = 15; +} + +message Complex { 
+ enum EnumNested { + UNKNOWN = 0; + FOO = 1; + BAR = 2; + } + + EnumNested special_enum = 3; + repeated EnumNested repeated_enum = 4; + + oneof special_oneof { + int32 oneof_int32 = 5; + bool oneof_bool = 6; + string oneof_string = 7; + Primitive oneof_primitive = 8; + } + + map x = 9; + map y = 10; + map z = 11; + + oneof second_oneof { + int64 oneof_int64 = 12; + double oneof_double = 13; + } +} + +message WktMessage { + google.protobuf.DoubleValue nullable_double = 1; + google.protobuf.FloatValue nullable_float = 2; + google.protobuf.Int32Value nullable_int32 = 3; + google.protobuf.Int64Value nullable_int64 = 4; + google.protobuf.UInt32Value nullable_uint32 = 5; + google.protobuf.UInt64Value nullable_uint64 = 6; + google.protobuf.BoolValue nullable_bool = 13; + google.protobuf.StringValue nullable_string = 14; + google.protobuf.BytesValue nullable_bytes = 15; + google.protobuf.Timestamp wkt_timestamp = 16; +} + +message LatLng { + double latitude = 1; + double longitude = 2; +} + +message LogicalTypes { + LatLng gps = 1; +} + +message MessageWithMeta { + option (proto3_schema_messages.message_meta).description = "Cool field"; + option (proto3_schema_messages.message_meta).rewrite = true; + + string field_with_description = 1 [(proto3_schema_messages.field_meta).description = "Cool field"]; + string field_with_foobar = 2 [(proto3_schema_messages.field_meta).foobar = 42]; + string field_with_deprecation = 3 [deprecated = true]; +} + diff --git a/sdks/java/extensions/protobuf/src/test/resources/README.md b/sdks/java/extensions/protobuf/src/test/resources/README.md new file mode 100644 index 000000000000..79083f5142b0 --- /dev/null +++ b/sdks/java/extensions/protobuf/src/test/resources/README.md @@ -0,0 +1,34 @@ + + +This recreates the proto descriptor set included in this resource directory. 
+ +```bash +export PROTO_INCLUDE= +``` +Execute the following command to create the pb files, in the beam root folder: + +```bash +protoc \ + -Isdks/java/extensions/protobuf/src/test/resources/ \ + -I$PROTO_INCLUDE \ + --descriptor_set_out=sdks/java/extensions/protobuf/src/test/resources/org/apache/beam/sdk/extensions/protobuf/test_option_v1.pb \ + --include_imports \ + sdks/java/extensions/protobuf/src/test/resources/test/option/v1/simple.proto +``` diff --git a/sdks/java/extensions/protobuf/src/test/resources/org/apache/beam/sdk/extensions/protobuf/test_option_v1.pb b/sdks/java/extensions/protobuf/src/test/resources/org/apache/beam/sdk/extensions/protobuf/test_option_v1.pb new file mode 100644 index 0000000000000000000000000000000000000000..4e97ad02a15b162a8fa085c4d766d5f3decb5c93 GIT binary patch literal 18745 zcmc&*Yj7mjRc1ySY37bJYW2Q)wO;Rf{Mhv-p2hNJZM2lJ9;3C^l1AGjvECFgv}Rh; ztVh$6>Ct-SqNoIZ098;>@FxWn1tGy92@qa^Kpy55z=Q-t2(SDIuY#hADt@N;L2=GK zx9{A3O*cPm_Q#g`_C4P@_uO;8bI-j!?5A$9i~EkV-?6Xu`p&@FJ=(k4vfXCC-5WUl z3UAYAs0TekkMahaU2S*l71^8yVfBhW-b>@{NP4W%1{ZaGt<7b{X+zr3tBxlhc>M ziU$L`>$aV)G08z54}q+C z^lur4+GI?mY|zpSMcWQ5^=(`);9?qCfg4dt9GNM^XkgPG3XSVgCL0=@xsYB8=9JfK z6WDC-HCFPwM?cAO?!f8~IML-ydeAjC24u!rE@$GO`U}GWkMRLN5i-}0931!DvuxTo zfV8U1-@{HtLhCQHlSkclcVPE>eH$Z|R=@FYL?Gdq+fjRyf#wMx1T|hMz*sQ#YobiL5z>al^Ve|8vLMBL}2hZl^g z@~5(F9LK27u;TXJt!iVXwzITe)iQdCG2XDcUR&JOvhsOtbNl)=ZA=R6P{X*|;hA|Y z2TNdz0M=fvu7JpdzXkeC>eH;i`z+OJ>snr7*0;^I&6`?5V%{`sJ6jr)2piRUeeq^h zE6N^Acekr`ZAwC*S)gKCs;F-6Y-lBYmQ8bw^;I+rddvixz*0@3)CpzQ*8n~A@-oZu z=z=INtuHQB*BiC1?X}wG;yU3?(WGAC|GIC`oQ;d73 z;(U0DIaX+j|2fM}L~K7Im3$k^QHWCZd@>Y82*d+6NbMx2?jVTe!aE3r2pXw>XNc1& z?Ol(wAfUX0kUnZOC6&i;JLbY5x8Gp1A)wS8`~4Ya#5pI@)kCJk=^QfMjp0%z4?Ys< z0zL^rHAOW~BX!CwdbZ3-zhiIc4n+dbO1hAjP|EISnm z9%<-##@gN9(SWv8R67)S3nwom&!d5SkG4ShKux^Eb3>YxP!iJ9k#psu#w7K&M)Nltd47&GxWOdy+sMWIq54Bm@bAA^jb=kf4snoSV9Tq z`Xg-Wo^{`9c(oFhP(^IpqVfrpoa9|#jCQ}#>{zbL_bkZ%1yjd&HGH+~z0u#pPH>4s z$V2U3$8KP~=Ne44Y!-WLhz@9x3%kAZTx{R&+I?$aH|+Nv!MGb%x79eX+ymnz0xV^X 
zjCmG2-W2_+yxU^8^&)iA7uacz0zL!>Hkt=^^OeTYVDCBOoT{H&2!4S+%h+WHdgygF z*3VUS*qZyD#%9}6a|gF;(L1wXA`q^nN$*jeZBt;4g<#U}yhm_gF^OQ)DN$d9n4jZMw?~Y)TD$y}&5+ zVMhWRt>xNEwNYPtZxuEpus2&@+paduoz3mFjjA@LY}jv&=bz9nmEX#;lD`7ezm=T< zkMZLY*B&%pZTI<#JhZ6zkpy>A^n!z>JJ1K#8+6gv*@dnHoCwTe^;=LqdIwbvs|kzC zg{6(eE>RxpI(5-`n12>USKS}K$oNuFjbGR`A6h*K;6eYG@4NG69vZ6n{En1A;LW^N zfd2~`EC16No8pBEP7|E6U|G||8iE&(rTIrgSxN0(S;kzpFhRRxlO*xL1QyOKu!y<$ zO!1`OWE1yXj*ZEu6nUlx=On_C%B?yFFWjm(Hf!d_;<_os^|NfeW4(6l&w9Km?IZvK z9=!Q;4R3n;rl5S4<+#6e#)QYTZ_@K@e7R<>z=D7kKz?j&tyPy{O_blmCddi6Sfm?3 zh#o~v!j}@h-Pu^Gnp)PsaK^P<*@abzhacV?zisPXQC)r`!-~qr8p}zm({WyHSe-Ud zHiZV`O^eu)cIw^<<6&{+w2AW1WSAy*(4O=B%N#dCk<0v@FthyA8CDXhKH9uT{$I`e zdX`Q5HnD*@jhgA-4nE`{ z;NU}DWj{H_PDC7m0gcmL2UX2ag;+d7rhId2x*t~}T*zYZG~|K$P{FZqrGJ6yaIp_{GC zp}FykRMtg*Q%cT#DZ}z!Go;(`o;5h&aJi+dmNEG=w!>bs3rq`lQq+#mxX10-R*P3X z&fy`s@$|1SRq7*U}fVo4)hHqls26>7ufOt8| z^AuF>)Pw~cp1Nhf(4`s)0ZY_N*J9m>`zg9Rj zT&`A3d#__RG2qx2USQ`2a1nIX;hVjy_n#JjiHq41zO2y8%KfM360dU0U(2v*$~@NsxTJ_it7mPqlH0>=ygG4oS8KQ-3bJN)NmxMN}qSd2p zL$pNO{)*7_AR5@V$@HPUKvfUy7R3MPS@;UUO`1xOu99RFUZQz_O?VR_Jg{$zolxv^ z(;4_x0-Q0GvDJlWBx0h8e?y2`7!hoIktxfhvFkXU1PA^tq2ONE= z-Vi9!NFu1DUq9#?=DCF^P*@9zppu6}ps*Z5UKMPdh$;!$?>@oFk+5lumCPK5jj~qo zbE379qW!Rfr9;*VqE3XBTpfl5YwxgAzSxkAoro(*+aK~CZ$|^Caj#_W2w<@IE<5eF zSaNux$&$wXW0p)F1D(cb$>&kfVD)Y0k8EP~M2jV{&-%v9xjh^av{6k2l?)$<2reF8 z!o>55sFLda10}Y<9}F`woMdK5%lZCb*`ULB%rZX5@vU~U(p2Q{pAydhd+fYlomT~T zTkb?Je)9WZ*F6Ti?sGmMQ&bZoA+*>-zSj+@hc@GhYJQc}WX(Arj;Dq)M7)-isp8j2 zOD;wWrHqkU(k#7SCoN;~T44Rf&rI56@*AXPI%Z}>QW>o$&2WB`^o))6J(y z&uokyk#^$8lQ#eS7U|JP=z;MZu!nt-s6@0VArfQyw@J~Y>0cW%rbR)KpefB1eVQ~) zjM4;Sy2mcWj46G}`W;d=WBQ538vu_)Hc}UiYnS=SX*gHBR7squ&yb$H zdDiD1Z%jP7jflIHS?;r>r4Sc)lr#WQTPKr^)`1{9l`m4lohEEexso@ zipnSfPu4|H;z~Tqc>y>|bO~^>B2wWfxrO0iAtu1dS||VqD=*+Zk-8G_WEG@3N~Eg@ zz+?ku{SyQXTPTuIf`zj3Q2|8)3RozzP6D8;d4hmo;cYwslCv)BodiT#?Wp1F)jCm- z$!?T&P6!gLth1A%@m?c-e-dfQy5+0_;$?I})!)VDl$S#dRI?-uN>(g}g5fD3?vBGx zMdGqniNYr`AAz5|h-8%#gAXeyZU)lPysS^6WT 
zgt#%J%M7v_iPHdM*kunXF{Z?+s6-MXRn{SDw97~pRmeyk$x|UKkXRisjutx~IgYf+ zMAjZEF?a)!NoBN_*{*jf2YNCfP zu1?~(BUM6@HAWP7P-6N(CMIZ8rl_>lMO3_bsp>(@9Hq%wsp^WT`NQ0`JHfw2c!quX zbm&6@w|xko@W*^eFzTy=@_(OVx)|CW_zFT?)$kv;;iWx(a3j7Y;1~1wbq; z)6t04p$H{a{UtH)VMv};R zk>ERs&kLk)Akuy;641x3Sf5c|G~N(U@Ba~iI*4`_<-Y_hFzfdO)(;|9J01({dMVZ} z%5}!?6V~L9A=Xzi(O{I+62QQ)KM-IaM6j=p00YZb*uy-uB0_?)kMV~b%_PPRMuNz4 z2`FIGA0?rHEvxK8j4fUkGQL1Clyd{dcv&V92h927L>z7p4La-_%p}I2aEzsiH_*)) zllV4Z6J`uOekIc-yYZ)lH2o$bP1X)B9@|z-a8inZpFNWcB22>@bK^}KU*z)TNnpgp zi!5OL8DT_Dyx=-;b)1Wk$c--w7kzz%i;$Gz*GNyqOnu+`;EHMs<7e^Pr<_^l-2>kC z@~EJ@+wAcqP~yFxBg{ueVG`-3fF$<&Pzp$5zjTlgvp=5>5&}1cHw7PlIE6O_AEpEa zLqC!dG@Ye_A%7tSs36D`mLe!v`g%%GuyluAN(?5!m>(t7M@NQ}V9vC#VC*lZg#~Nx zu**r-N>2S4;XXFXT*1*$#D>S!ze1oRL(b#u6rSef{mB%b IR(I6@3vTl~-T(jq literal 0 HcmV?d00001 diff --git a/sdks/java/extensions/protobuf/src/test/resources/test/option/v1/option.proto b/sdks/java/extensions/protobuf/src/test/resources/test/option/v1/option.proto new file mode 100644 index 000000000000..ca40119dce3f --- /dev/null +++ b/sdks/java/extensions/protobuf/src/test/resources/test/option/v1/option.proto @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +syntax = "proto3"; + +package test.option.v1; + +import "google/protobuf/descriptor.proto"; + +extend google.protobuf.FileOptions { + double fileoption_double = 66666700; + float fileoption_float = 66666701; + int32 fileoption_int32 = 66666702; + int64 fileoption_int64 = 66666703; + uint32 fileoption_uint32 = 66666704; + uint64 fileoption_uint64 = 66666705; + sint32 fileoption_sint32 = 66666706; + sint64 fileoption_sint64 = 66666707; + fixed32 fileoption_fixed32 = 66666708; + fixed64 fileoption_fixed64 = 66666709; + sfixed32 fileoption_sfixed32 = 66666710; + sfixed64 fileoption_sfixed64 = 66666711; + bool fileoption_bool = 66666712; + string fileoption_string = 66666713; + bytes fileoption_bytes = 66666714; + OptionMessage fileoption_message = 66666715; + OptionEnum fileoption_enum = 66666716; +} + +extend google.protobuf.MessageOptions { + double messageoption_double = 66666700; + float messageoption_float = 66666701; + int32 messageoption_int32 = 66666702; + int64 messageoption_int64 = 66666703; + uint32 messageoption_uint32 = 66666704; + uint64 messageoption_uint64 = 66666705; + sint32 messageoption_sint32 = 66666706; + sint64 messageoption_sint64 = 66666707; + fixed32 messageoption_fixed32 = 66666708; + fixed64 messageoption_fixed64 = 66666709; + sfixed32 messageoption_sfixed32 = 66666710; + sfixed64 messageoption_sfixed64 = 66666711; + bool messageoption_bool = 66666712; + string messageoption_string = 66666713; + bytes messageoption_bytes = 66666714; + OptionMessage messageoption_message = 66666715; + OptionEnum messageoption_enum = 66666716; + + repeated double messageoption_repeated_double = 66666800; + repeated float messageoption_repeated_float = 66666801; + repeated int32 messageoption_repeated_int32 = 66666802; + repeated int64 messageoption_repeated_int64 = 66666803; + repeated uint32 messageoption_repeated_uint32 = 66666804; + repeated uint64 messageoption_repeated_uint64 = 66666805; + repeated sint32 messageoption_repeated_sint32 = 66666806; 
+ repeated sint64 messageoption_repeated_sint64 = 66666807; + repeated fixed32 messageoption_repeated_fixed32 = 66666808; + repeated fixed64 messageoption_repeated_fixed64 = 66666809; + repeated sfixed32 messageoption_repeated_sfixed32 = 66666810; + repeated sfixed64 messageoption_repeated_sfixed64 = 66666811; + repeated bool messageoption_repeated_bool = 66666812; + repeated string messageoption_repeated_string = 66666813; + repeated bytes messageoption_repeated_bytes = 66666814; + repeated OptionMessage messageoption_repeated_message = 66666815; + repeated OptionEnum messageoption_repeated_enum = 66666816; +} + +extend google.protobuf.FieldOptions { + double fieldoption_double = 66666700; + float fieldoption_float = 66666701; + int32 fieldoption_int32 = 66666702; + int64 fieldoption_int64 = 66666703; + uint32 fieldoption_uint32 = 66666704; + uint64 fieldoption_uint64 = 66666705; + sint32 fieldoption_sint32 = 66666706; + sint64 fieldoption_sint64 = 66666707; + fixed32 fieldoption_fixed32 = 66666708; + fixed64 fieldoption_fixed64 = 66666709; + sfixed32 fieldoption_sfixed32 = 66666710; + sfixed64 fieldoption_sfixed64 = 66666711; + bool fieldoption_bool = 66666712; + string fieldoption_string = 66666713; + bytes fieldoption_bytes = 66666714; + OptionMessage fieldoption_message = 66666715; + OptionEnum fieldoption_enum = 66666716; + + repeated double fieldoption_repeated_double = 66666800; + repeated float fieldoption_repeated_float = 66666801; + repeated int32 fieldoption_repeated_int32 = 66666802; + repeated int64 fieldoption_repeated_int64 = 66666803; + repeated uint32 fieldoption_repeated_uint32 = 66666804; + repeated uint64 fieldoption_repeated_uint64 = 66666805; + repeated sint32 fieldoption_repeated_sint32 = 66666806; + repeated sint64 fieldoption_repeated_sint64 = 66666807; + repeated fixed32 fieldoption_repeated_fixed32 = 66666808; + repeated fixed64 fieldoption_repeated_fixed64 = 66666809; + repeated sfixed32 fieldoption_repeated_sfixed32 = 66666810; + 
repeated sfixed64 fieldoption_repeated_sfixed64 = 66666811; + repeated bool fieldoption_repeated_bool = 66666812; + repeated string fieldoption_repeated_string = 66666813; + repeated bytes fieldoption_repeated_bytes = 66666814; + repeated OptionMessage fieldoption_repeated_message = 66666815; + repeated OptionEnum fieldoption_repeated_enum = 66666816; +} + +enum OptionEnum { + DEFAULT = 0; + ENUM1 = 1; + ENUM2 = 2; +} + +message OptionMessage { + string string = 1; + repeated string repeated_string = 2; + + int32 int32 = 3; + repeated int32 repeated_int32 = 4; + + int64 int64 = 5; + + OptionEnum test_enum = 6; +} \ No newline at end of file diff --git a/sdks/java/extensions/protobuf/src/test/resources/test/option/v1/simple.proto b/sdks/java/extensions/protobuf/src/test/resources/test/option/v1/simple.proto new file mode 100644 index 000000000000..1750ddfb3ca5 --- /dev/null +++ b/sdks/java/extensions/protobuf/src/test/resources/test/option/v1/simple.proto @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +syntax = "proto3"; + +import "test/option/v1/option.proto"; + +package test.option.v1; + +message MessageWithOptions { + string test_name = 1; + int32 test_index = 2; + + int32 field_with_fieldoption_double = 700 [(test.option.v1.fieldoption_double) = 100.1]; + int32 field_with_fieldoption_float = 701 [(test.option.v1.fieldoption_float) = 101.2]; + int32 field_with_fieldoption_int32 = 702 [(test.option.v1.fieldoption_int32) = 102]; + int32 field_with_fieldoption_int64 = 703 [(test.option.v1.fieldoption_int64) = 103]; + int32 field_with_fieldoption_uint32 = 704 [(test.option.v1.fieldoption_uint32) = 104]; + int32 field_with_fieldoption_uint64 = 705 [(test.option.v1.fieldoption_uint64) = 105]; + int32 field_with_fieldoption_sint32 = 706 [(test.option.v1.fieldoption_sint32) = 106]; + int32 field_with_fieldoption_sint64 = 707 [(test.option.v1.fieldoption_sint64) = 107]; + int32 field_with_fieldoption_fixed32 = 708; + int32 field_with_fieldoption_fixed64 = 709; + int32 field_with_fieldoption_sfixed32 = 710; + int32 field_with_fieldoption_sfixed64 = 711; + int32 field_with_fieldoption_bool = 712 [(test.option.v1.fieldoption_bool) = true]; + int32 field_with_fieldoption_string = 713 [(test.option.v1.fieldoption_string) = "Oh yeah"]; + int32 field_with_fieldoption_bytes = 714; + int32 field_with_fieldoption_message = 715; + int32 field_with_fieldoption_enum = 716 [(test.option.v1.fieldoption_enum) = ENUM1]; + + int32 field_with_fieldoption_repeated_double = 800; + int32 field_with_fieldoption_repeated_float = 801; + int32 field_with_fieldoption_repeated_int32 = 802; + int32 field_with_fieldoption_repeated_int64 = 803; + int32 field_with_fieldoption_repeated_uint32 = 804; + int32 field_with_fieldoption_repeated_uint64 = 805; + int32 field_with_fieldoption_repeated_sint32 = 806; + int32 field_with_fieldoption_repeated_sint64 = 807; + int32 field_with_fieldoption_repeated_fixed32 = 808; + int32 field_with_fieldoption_repeated_fixed64 = 809; + int32 
field_with_fieldoption_repeated_sfixed32 = 810; + int32 field_with_fieldoption_repeated_sfixed64 = 811; + int32 field_with_fieldoption_repeated_bool = 812; + int32 field_with_fieldoption_repeated_string = 813 [(test.option.v1.fieldoption_repeated_string) = "Oh yeah", + (test.option.v1.fieldoption_repeated_string) = "Oh no"]; + int32 field_with_fieldoption_repeated_bytes = 814; + int32 field_with_fieldoption_repeated_message = 815; + int32 field_with_fieldoption_repeated_enum = 816; + +} + From e5d5c7774d2da21f17ef7aeb4030af66f8372ff5 Mon Sep 17 00:00:00 2001 From: Alex Van Boxel Date: Sun, 10 Nov 2019 22:29:20 +0100 Subject: [PATCH 2/2] [BEAM-7274] Split RowWithGetters in two implementations Split RowWithGetters in an implementation with a version that caches collections, named RowWithGettersCachedCollection and one where all the types are handled by the FieldGetters named RowWithGetters. --- .../java/org/apache/beam/sdk/values/Row.java | 7 +- .../beam/sdk/values/RowWithGetters.java | 61 +--------- .../RowWithGettersCachedCollections.java | 114 ++++++++++++++++++ .../sdk/extensions/protobuf/ProtoSchema.java | 23 ++-- 4 files changed, 136 insertions(+), 69 deletions(-) create mode 100644 sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGettersCachedCollections.java diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java index 8e5a7794bc8e..b8af5acfc855 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java @@ -761,8 +761,11 @@ public Row build() { return new RowWithStorage(schema, storageValues); } else if (fieldValueGetterFactory != null) { checkState(getterTarget != null, "getters require withGetterTarget."); - return new RowWithGetters( - schema, fieldValueGetterFactory, getterTarget, collectionHandledByGetter); + if (collectionHandledByGetter) { + return new 
RowWithGetters(schema, fieldValueGetterFactory, getterTarget); + } else { + return new RowWithGettersCachedCollections(schema, fieldValueGetterFactory, getterTarget); + } } else { return new RowWithStorage(schema, Collections.emptyList()); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java index c0fe11f331f2..8a5c11c58866 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java @@ -18,7 +18,6 @@ package org.apache.beam.sdk.values; import java.util.List; -import java.util.Map; import java.util.Objects; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -27,9 +26,6 @@ import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; -import org.apache.beam.sdk.schemas.Schema.TypeName; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; /** * A Concrete subclass of {@link Row} that delegates to a set of provided {@link FieldValueGetter}s. @@ -39,25 +35,16 @@ * the appropriate fields from the POJO. 
*/ public class RowWithGetters extends Row { - private final Factory> fieldValueGetterFactory; - private final Object getterTarget; - private final List getters; - - private final Map cachedLists = Maps.newHashMap(); - private final Map cachedMaps = Maps.newHashMap(); - - private final boolean collectionHandledByGetter; + final Factory> fieldValueGetterFactory; + final Object getterTarget; + final List getters; RowWithGetters( - Schema schema, - Factory> getterFactory, - Object getterTarget, - boolean collectionHandledByGetter) { + Schema schema, Factory> getterFactory, Object getterTarget) { super(schema); this.fieldValueGetterFactory = getterFactory; this.getterTarget = getterTarget; this.getters = fieldValueGetterFactory.create(getterTarget.getClass(), schema); - this.collectionHandledByGetter = collectionHandledByGetter; } @Nullable @@ -73,45 +60,9 @@ public T getValue(int fieldIdx) { return fieldValue != null ? getValue(type, fieldValue, fieldIdx) : null; } - private List getListValue(FieldType elementType, Object fieldValue) { - Iterable iterable = (Iterable) fieldValue; - List list = Lists.newArrayList(); - for (Object o : iterable) { - list.add(getValue(elementType, o, null)); - } - return list; - } - - private Map getMapValue(FieldType keyType, FieldType valueType, Map fieldValue) { - Map returnMap = Maps.newHashMap(); - for (Map.Entry entry : fieldValue.entrySet()) { - returnMap.put( - getValue(keyType, entry.getKey(), null), getValue(valueType, entry.getValue(), null)); - } - return returnMap; - } - @SuppressWarnings({"TypeParameterUnusedInFormals", "unchecked"}) - private T getValue(FieldType type, Object fieldValue, @Nullable Integer cacheKey) { - if (type.getTypeName().equals(TypeName.ROW) && !collectionHandledByGetter) { - return (T) - new RowWithGetters(type.getRowSchema(), fieldValueGetterFactory, fieldValue, false); - } else if (type.getTypeName().equals(TypeName.ARRAY) && !collectionHandledByGetter) { - return cacheKey != null - ? 
(T) - cachedLists.computeIfAbsent( - cacheKey, i -> getListValue(type.getCollectionElementType(), fieldValue)) - : (T) getListValue(type.getCollectionElementType(), fieldValue); - } else if (type.getTypeName().equals(TypeName.MAP) && !collectionHandledByGetter) { - Map map = (Map) fieldValue; - return cacheKey != null - ? (T) - cachedMaps.computeIfAbsent( - cacheKey, i -> getMapValue(type.getMapKeyType(), type.getMapValueType(), map)) - : (T) getMapValue(type.getMapKeyType(), type.getMapValueType(), map); - } else { - return (T) fieldValue; - } + protected T getValue(FieldType type, Object fieldValue, @Nullable Integer cacheKey) { + return (T) fieldValue; } @Override diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGettersCachedCollections.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGettersCachedCollections.java new file mode 100644 index 000000000000..cbbcc7619857 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGettersCachedCollections.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.values; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import javax.annotation.Nullable; +import org.apache.beam.sdk.schemas.Factory; +import org.apache.beam.sdk.schemas.FieldValueGetter; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.Schema.TypeName; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; + +/** + * A Concrete subclass of {@link Row} that delegates to a set of provided {@link FieldValueGetter}s. + * This is a special version of {@link RowWithGetters} that caches the map and list collections. + * + *

This allows us to have {@link Row} objects for which the actual storage is in another object. + * For example, the user's type may be a POJO, in which case the provided getters will simple read + * the appropriate fields from the POJO. + */ +public class RowWithGettersCachedCollections extends RowWithGetters { + private final Map cachedLists = Maps.newHashMap(); + private final Map cachedMaps = Maps.newHashMap(); + + RowWithGettersCachedCollections( + Schema schema, Factory> getterFactory, Object getterTarget) { + super(schema, getterFactory, getterTarget); + } + + private List getListValue(FieldType elementType, Object fieldValue) { + Iterable iterable = (Iterable) fieldValue; + List list = Lists.newArrayList(); + for (Object o : iterable) { + list.add(getValue(elementType, o, null)); + } + return list; + } + + private Map getMapValue(FieldType keyType, FieldType valueType, Map fieldValue) { + Map returnMap = Maps.newHashMap(); + for (Map.Entry entry : fieldValue.entrySet()) { + returnMap.put( + getValue(keyType, entry.getKey(), null), getValue(valueType, entry.getValue(), null)); + } + return returnMap; + } + + @SuppressWarnings({"TypeParameterUnusedInFormals", "unchecked"}) + @Override + protected T getValue(FieldType type, Object fieldValue, @Nullable Integer cacheKey) { + if (type.getTypeName().equals(TypeName.ROW)) { + return (T) + new RowWithGettersCachedCollections( + type.getRowSchema(), fieldValueGetterFactory, fieldValue); + } else if (type.getTypeName().equals(TypeName.ARRAY)) { + return cacheKey != null + ? (T) + cachedLists.computeIfAbsent( + cacheKey, i -> getListValue(type.getCollectionElementType(), fieldValue)) + : (T) getListValue(type.getCollectionElementType(), fieldValue); + } else if (type.getTypeName().equals(TypeName.MAP)) { + Map map = (Map) fieldValue; + return cacheKey != null + ? 
(T) + cachedMaps.computeIfAbsent( + cacheKey, i -> getMapValue(type.getMapKeyType(), type.getMapValueType(), map)) + : (T) getMapValue(type.getMapKeyType(), type.getMapValueType(), map); + } else { + return (T) fieldValue; + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null) { + return false; + } + if (o instanceof RowWithGettersCachedCollections) { + RowWithGettersCachedCollections other = (RowWithGettersCachedCollections) o; + return Objects.equals(getSchema(), other.getSchema()) + && Objects.equals(getterTarget, other.getterTarget); + } else if (o instanceof Row) { + return super.equals(o); + } + return false; + } + + @Override + public int hashCode() { + return Objects.hash(getSchema(), getterTarget); + } +} diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchema.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchema.java index b05be3fddc96..351da3ff9f41 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchema.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchema.java @@ -58,28 +58,27 @@ * *
    *
  • Protobuf oneOf fields are mapped to nullable fields and flattened into the parent row. - *
  • Protobuf primitives are mapped to it's nullable counter part. + *
  • Protobuf primitives are mapped to their non-nullable counterparts. *
  • Protobuf maps are mapped to nullable maps, where empty maps are mapped to the null value. *
  • Protobuf repeatables are mapped to nullable arrays, where empty arrays are mapped to the * null value. *
  • Protobuf enums are mapped to non-nullable string values. - *
  • Enum map to their string representation *
* *

Protobuf Well Know Types are handled by the Beam Schema system. Beam knows of the following * Well Know Types: * *

    - *
  • google.protobuf.Timestamp maps to a nullable Field.DATATIME. - *
  • google.protobuf.StringValue maps to a nullable Field.STRING. - *
  • google.protobuf.DoubleValue maps to a nullable Field.DOUBLE. - *
  • google.protobuf.FloatValue maps to a nullable Field.FLOAT. - *
  • google.protobuf.BytesValue maps to a nullable Field.BYTES. - *
  • google.protobuf.BoolValue maps to a nullable Field.BOOL. - *
  • google.protobuf.Int64Value maps to a nullable Field.INT64. - *
  • google.protobuf.Int32Value maps to a nullable Field.INT32. - *
  • google.protobuf.UInt64Value maps to a nullable Field.INT64. - *
  • google.protobuf.UInt32Value maps to a nullable Field.INT32. + *
  • google.protobuf.Timestamp maps to a nullable Field.DATETIME + *
  • google.protobuf.StringValue maps to a nullable Field.STRING + *
  • google.protobuf.DoubleValue maps to a nullable Field.DOUBLE + *
  • google.protobuf.FloatValue maps to a nullable Field.FLOAT + *
  • google.protobuf.BytesValue maps to a nullable Field.BYTES + *
  • google.protobuf.BoolValue maps to a nullable Field.BOOL + *
  • google.protobuf.Int64Value maps to a nullable Field.INT64 + *
  • google.protobuf.Int32Value maps to a nullable Field.INT32 + *
  • google.protobuf.UInt64Value maps to a nullable Field.INT64 + *
  • google.protobuf.UInt32Value maps to a nullable Field.INT32 *
*/ @Experimental(Experimental.Kind.SCHEMAS)