From 7a0fe93b2bd28d96a43583e123b09ccd4d0ca033 Mon Sep 17 00:00:00 2001 From: Alex Van Boxel Date: Sat, 11 May 2019 16:56:20 +0200 Subject: [PATCH] [BEAM-7274] Add DynamicMessage Schema support Add DynamicMessage schema support. This is different from generated classes as it uses the proto descriptors. It uses the ProtoDomain as an index for searching embedded messages. --- .../sdk/schemas/logicaltypes/OneOfType.java | 9 +- .../sdk/extensions/protobuf/ProtoDomain.java | 12 +- .../protobuf/ProtoDynamicMessageSchema.java | 844 ++++++++++++++++++ .../protobuf/ProtoSchemaTranslator.java | 66 +- .../ProtoDynamicMessageSchemaTest.java | 282 ++++++ .../extensions/protobuf/TestProtoSchemas.java | 86 +- 6 files changed, 1268 insertions(+), 31 deletions(-) create mode 100644 sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchema.java create mode 100644 sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchemaTest.java diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java index 46b219c027f1..af2747519d8f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java @@ -118,8 +118,13 @@ public FieldType getBaseType() { } /** Create a {@link Value} specifying which field to set and the value to set. */ - public Value createValue(String caseType, T value) { - return createValue(getCaseEnumType().valueOf(caseType), value); + public Value createValue(String caseValue, T value) { + return createValue(getCaseEnumType().valueOf(caseValue), value); + } + + /** Create a {@link Value} specifying which field to set and the value to set. 
*/ + public Value createValue(int caseValue, T value) { + return createValue(getCaseEnumType().valueOf(caseValue), value); } /** Create a {@link Value} specifying which field to set and the value to set. */ diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDomain.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDomain.java index e9a5d48ed35b..c13cf882c547 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDomain.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDomain.java @@ -163,15 +163,15 @@ private void indexDescriptorByName() { .values() .forEach( fileDescriptor -> { - fileDescriptor - .getMessageTypes() - .forEach( - descriptor -> { - descriptorMap.put(descriptor.getFullName(), descriptor); - }); + fileDescriptor.getMessageTypes().forEach(this::indexDescriptor); }); } + private void indexDescriptor(Descriptors.Descriptor descriptor) { + descriptorMap.put(descriptor.getFullName(), descriptor); + descriptor.getNestedTypes().forEach(this::indexDescriptor); + } + private void indexOptionsByNumber(Collection fileDescriptors) { fieldOptionMap = new HashMap<>(); fileOptionMap = new HashMap<>(); diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchema.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchema.java new file mode 100644 index 000000000000..93bf00256698 --- /dev/null +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchema.java @@ -0,0 +1,844 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.protobuf; + +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.getFieldNumber; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.getMapKeyMessageName; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.getMapValueMessageName; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.getMessageName; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withFieldNumber; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withMessageName; + +import com.google.protobuf.ByteString; +import com.google.protobuf.Descriptors; +import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.DynamicMessage; +import com.google.protobuf.Message; +import java.io.Serializable; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType; +import org.apache.beam.sdk.schemas.logicaltypes.NanosDuration; +import org.apache.beam.sdk.schemas.logicaltypes.NanosInstant; +import 
org.apache.beam.sdk.schemas.logicaltypes.OneOfType; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.Row; + +@Experimental(Experimental.Kind.SCHEMAS) +public class ProtoDynamicMessageSchema implements Serializable { + public static final long serialVersionUID = 1L; + + /** + * Context of the schema, the context can be generated from a source schema or descriptors. The + * ability of converting back from Row to proto depends on the type of context. + */ + private final Context context; + + /** The toRow function to convert the Message to a Row. */ + private transient SerializableFunction toRowFunction; + + /** The fromRow function to convert the Row to a Message. */ + private transient SerializableFunction fromRowFunction; + + /** List of field converters for each field in the row. */ + private transient List converters; + + private ProtoDynamicMessageSchema(String messageName, ProtoDomain domain) { + this.context = new DescriptorContext(messageName, domain); + readResolve(); + } + + private ProtoDynamicMessageSchema(Context context) { + this.context = context; + readResolve(); + } + + /** + * Create a new ProtoDynamicMessageSchema from a {@link ProtoDomain} and for a message. The + * message needs to be in the domain and must be referenced by its fully qualified name. + */ + public static ProtoDynamicMessageSchema forDescriptor(ProtoDomain domain, String messageName) { + return new ProtoDynamicMessageSchema(messageName, domain); + } + + /** + * Create a new ProtoDynamicMessageSchema from a {@link ProtoDomain} and for a descriptor. The + * descriptor is only used for its name; that name will be used for a search in the domain. 
+ */ + public static ProtoDynamicMessageSchema forDescriptor( + ProtoDomain domain, Descriptors.Descriptor descriptor) { + return new ProtoDynamicMessageSchema<>(descriptor.getFullName(), domain); + } + + static ProtoDynamicMessageSchema forContext(Context context, Schema.Field field) { + return new ProtoDynamicMessageSchema<>(context.getSubContext(field)); + } + + static ProtoDynamicMessageSchema forSchema(Schema schema) { + return new ProtoDynamicMessageSchema<>(new Context(schema, Message.class)); + } + + /** Initialize the transient fields after deserialization or construction. */ + private Object readResolve() { + converters = createConverters(context.getSchema()); + toRowFunction = new MessageToRowFunction(); + fromRowFunction = new RowToMessageFunction(); + return this; + } + + Convert createConverter(Schema.Field field) { + Schema.FieldType fieldType = field.getType(); + String messageName = getMessageName(fieldType); + if (messageName != null && messageName.length() > 0) { + Schema.Field valueField = + Schema.Field.of("value", withFieldNumber(Schema.FieldType.BOOLEAN, 1)); + switch (messageName) { + case "google.protobuf.StringValue": + case "google.protobuf.DoubleValue": + case "google.protobuf.FloatValue": + case "google.protobuf.BoolValue": + case "google.protobuf.Int64Value": + case "google.protobuf.Int32Value": + case "google.protobuf.UInt64Value": + case "google.protobuf.UInt32Value": + return new WrapperConvert(field, new PrimitiveConvert(valueField)); + case "google.protobuf.BytesValue": + return new WrapperConvert(field, new BytesConvert(valueField)); + case "google.protobuf.Timestamp": + case "google.protobuf.Duration": + // handled by logical type case + break; + } + } + switch (fieldType.getTypeName()) { + case BYTE: + case INT16: + case INT32: + case INT64: + case FLOAT: + case DOUBLE: + case STRING: + case BOOLEAN: + return new PrimitiveConvert(field); + case BYTES: + return new BytesConvert(field); + case ARRAY: + case ITERABLE: + return new 
ArrayConvert(this, field); + case MAP: + return new MapConvert(this, field); + case LOGICAL_TYPE: + String identifier = field.getType().getLogicalType().getIdentifier(); + switch (identifier) { + case ProtoSchemaLogicalTypes.Fixed32.IDENTIFIER: + case ProtoSchemaLogicalTypes.Fixed64.IDENTIFIER: + case ProtoSchemaLogicalTypes.SFixed32.IDENTIFIER: + case ProtoSchemaLogicalTypes.SFixed64.IDENTIFIER: + case ProtoSchemaLogicalTypes.SInt32.IDENTIFIER: + case ProtoSchemaLogicalTypes.SInt64.IDENTIFIER: + case ProtoSchemaLogicalTypes.UInt32.IDENTIFIER: + case ProtoSchemaLogicalTypes.UInt64.IDENTIFIER: + return new LogicalTypeConvert(field, fieldType.getLogicalType()); + case NanosInstant.IDENTIFIER: + return new TimestampConvert(field); + case NanosDuration.IDENTIFIER: + return new DurationConvert(field); + case EnumerationType.IDENTIFIER: + return new EnumConvert(field, fieldType.getLogicalType()); + case OneOfType.IDENTIFIER: + return new OneOfConvert(this, field, fieldType.getLogicalType()); + default: + throw new IllegalStateException("Unexpected logical type : " + identifier); + } + case ROW: + return new MessageConvert(this, field); + default: + throw new IllegalStateException("Unexpected value: " + fieldType); + } + } + + private List createConverters(Schema schema) { + List fieldOverlays = new ArrayList<>(); + for (Schema.Field field : schema.getFields()) { + fieldOverlays.add(createConverter(field)); + } + return fieldOverlays; + } + + public Schema getSchema() { + return context.getSchema(); + } + + public SerializableFunction getToRowFunction() { + return toRowFunction; + } + + public SerializableFunction getFromRowFunction() { + return fromRowFunction; + } + + /** + * Context that only has enough information to convert a proto message to a Row. This can be used + * for arbitrary conventions, like decoding messages in proto options. 
+ */ + static class Context implements Serializable { + private final Schema schema; + + /** + * Base class for the protobuf message. Normally this is DynamicMessage, but as this schema + * class is also used to decode protobuf options this can be normal Message instances. + */ + private Class baseClass; + + Context(Schema schema, Class baseClass) { + this.schema = schema; + this.baseClass = baseClass; + } + + public Schema getSchema() { + return schema; + } + + public Class getBaseClass() { + return baseClass; + } + + public DynamicMessage.Builder invokeNewBuilder() { + throw new IllegalStateException("Should not be calling invokeNewBuilder"); + } + + public Context getSubContext(Schema.Field field) { + return new Context(field.getType().getRowSchema(), Message.class); + } + } + + /** + * Context that contains the full {@link ProtoDomain} and a reference to the message name. The full + * domain is needed for converting Rows back to the original proto messages. + */ + static class DescriptorContext extends Context { + private final String messageName; + private final ProtoDomain domain; + private transient Descriptors.Descriptor descriptor; + + DescriptorContext(String messageName, ProtoDomain domain) { + super( + ProtoSchemaTranslator.getSchema(domain.getDescriptor(messageName)), DynamicMessage.class); + this.messageName = messageName; + this.domain = domain; + } + + @Override + public DynamicMessage.Builder invokeNewBuilder() { + if (descriptor == null) { + descriptor = domain.getDescriptor(messageName); + } + return DynamicMessage.newBuilder(descriptor); + } + + @Override + public Context getSubContext(Schema.Field field) { + String messageName = getMessageName(field.getType()); + return new DescriptorContext(messageName, domain); + } + } + + /** + * Base converter class for converting from proto values to row values. 
The converter mainly works + * on fields in proto messages but also has methods to convert individual elements (example, for + * elements in Lists or Maps). + */ + abstract static class Convert { + private int number; + + Convert(Schema.Field field) { + try { + this.number = getFieldNumber(field.getType()); + } catch (NumberFormatException e) { + this.number = -1; + } + } + + FieldDescriptor getFieldDescriptor(Message message) { + return message.getDescriptorForType().findFieldByNumber(number); + } + + FieldDescriptor getFieldDescriptor(Message.Builder message) { + return message.getDescriptorForType().findFieldByNumber(number); + } + + /** Get a proto field and convert it into a row value. */ + abstract Object getFromProtoMessage(Message message); + + /** Convert a proto value into a row value. */ + abstract ValueT convertFromProtoValue(Object object); + + /** Convert a row value and set it on a proto message. */ + abstract void setOnProtoMessage(Message.Builder object, InT value); + + /** Convert a row value into a proto value. */ + abstract Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value); + } + + /** Converter for primitive proto values. */ + static class PrimitiveConvert extends Convert { + PrimitiveConvert(Schema.Field field) { + super(field); + } + + @Override + Object getFromProtoMessage(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + return convertFromProtoValue(message.getField(fieldDescriptor)); + } + + @Override + Object convertFromProtoValue(Object object) { + return object; + } + + @Override + void setOnProtoMessage(Message.Builder message, Object value) { + message.setField(getFieldDescriptor(message), value); + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } + + /** + * Converter for Bytes. Protobuf Bytes are natively represented as ByteStrings that requires + * special handling for byte[] of size 0. 
+ */ + static class BytesConvert extends PrimitiveConvert { + BytesConvert(Schema.Field field) { + super(field); + } + + @Override + Object convertFromProtoValue(Object object) { + // return object; + return ((ByteString) object).toByteArray(); + } + + @Override + void setOnProtoMessage(Message.Builder message, Object value) { + if (value != null && ((byte[]) value).length > 0) { + // Protobuf messages BYTES doesn't like empty bytes?! + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + message.setField(fieldDescriptor, convertToProtoValue(fieldDescriptor, value)); + } + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + if (value != null) { + return ByteString.copyFrom((byte[]) value); + } + return null; + } + } + + /** + * Specific converter for Proto Wrapper values as they are translated into nullable row values. + */ + static class WrapperConvert extends Convert { + private Convert valueConvert; + + WrapperConvert(Schema.Field field, Convert valueConvert) { + super(field); + this.valueConvert = valueConvert; + } + + @Override + Object getFromProtoMessage(Message message) { + if (message.hasField(getFieldDescriptor(message))) { + Message wrapper = (Message) message.getField(getFieldDescriptor(message)); + return valueConvert.getFromProtoMessage(wrapper); + } + return null; + } + + @Override + Object convertFromProtoValue(Object object) { + return object; + } + + @Override + void setOnProtoMessage(Message.Builder message, Object value) { + if (value != null) { + DynamicMessage.Builder builder = + DynamicMessage.newBuilder(getFieldDescriptor(message).getMessageType()); + valueConvert.setOnProtoMessage(builder, value); + message.setField(getFieldDescriptor(message), builder.build()); + } + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } + + static class TimestampConvert extends Convert { + + TimestampConvert(Schema.Field field) { + 
super(field); + } + + @Override + Object getFromProtoMessage(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + if (message.hasField(fieldDescriptor)) { + Message wrapper = (Message) message.getField(fieldDescriptor); + return convertFromProtoValue(wrapper); + } + return null; + } + + @Override + Object convertFromProtoValue(Object object) { + Message timestamp = (Message) object; + Descriptors.Descriptor timestampDescriptor = timestamp.getDescriptorForType(); + FieldDescriptor secondField = timestampDescriptor.findFieldByNumber(1); + FieldDescriptor nanoField = timestampDescriptor.findFieldByNumber(2); + long second = (long) timestamp.getField(secondField); + int nano = (int) timestamp.getField(nanoField); + return Instant.ofEpochSecond(second, nano); + } + + @Override + void setOnProtoMessage(Message.Builder message, Object value) { + if (value != null) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + message.setField(fieldDescriptor, convertToProtoValue(fieldDescriptor, value)); + } + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + Row row = (Row) value; + return com.google.protobuf.Timestamp.newBuilder() + .setSeconds(row.getInt64(0)) + .setNanos(row.getInt32(1)) + .build(); + } + } + + static class DurationConvert extends Convert { + + DurationConvert(Schema.Field field) { + super(field); + } + + @Override + Object getFromProtoMessage(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + if (message.hasField(fieldDescriptor)) { + Message wrapper = (Message) message.getField(fieldDescriptor); + return convertFromProtoValue(wrapper); + } + return null; + } + + @Override + Duration convertFromProtoValue(Object object) { + Message timestamp = (Message) object; + Descriptors.Descriptor timestampDescriptor = timestamp.getDescriptorForType(); + FieldDescriptor secondField = timestampDescriptor.findFieldByNumber(1); + FieldDescriptor 
nanoField = timestampDescriptor.findFieldByNumber(2); + long second = (long) timestamp.getField(secondField); + int nano = (int) timestamp.getField(nanoField); + return Duration.ofSeconds(second, nano); + } + + @Override + void setOnProtoMessage(Message.Builder message, Object value) { + if (value != null) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + message.setField(fieldDescriptor, convertToProtoValue(fieldDescriptor, value)); + } + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + Row row = (Row) value; + return com.google.protobuf.Duration.newBuilder() + .setSeconds(row.getInt64(0)) + .setNanos(row.getInt32(1)) + .build(); + } + } + + static class MessageConvert extends Convert { + private final SerializableFunction fromRowFunction; + private final SerializableFunction toRowFunction; + + MessageConvert(ProtoDynamicMessageSchema rootProtoSchema, Schema.Field field) { + super(field); + ProtoDynamicMessageSchema protoSchema = + ProtoDynamicMessageSchema.forContext(rootProtoSchema.context, field); + toRowFunction = protoSchema.getToRowFunction(); + fromRowFunction = protoSchema.getFromRowFunction(); + } + + @Override + Object getFromProtoMessage(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + if (message.hasField(fieldDescriptor)) { + return convertFromProtoValue(message.getField(fieldDescriptor)); + } + return null; + } + + @Override + Object convertFromProtoValue(Object object) { + return toRowFunction.apply(object); + } + + @Override + void setOnProtoMessage(Message.Builder message, Object value) { + if (value != null) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + message.setField(fieldDescriptor, convertToProtoValue(fieldDescriptor, value)); + } + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + return fromRowFunction.apply(value); + } + } + + /** + * Proto has a well defined way of 
storing maps, by having a Message with two fields, named "key" + * and "value" in a repeated field. This overlay translates between Row.map and the Protobuf + * map. + */ + static class MapConvert extends Convert { + private Convert key; + private Convert value; + + MapConvert(ProtoDynamicMessageSchema protoSchema, Schema.Field field) { + super(field); + Schema.FieldType fieldType = field.getType(); + key = + protoSchema.createConverter( + Schema.Field.of( + "KEY", + withMessageName(fieldType.getMapKeyType(), getMapKeyMessageName(fieldType)))); + value = + protoSchema.createConverter( + Schema.Field.of( + "VALUE", + withMessageName(fieldType.getMapValueType(), getMapValueMessageName(fieldType)))); + } + + @Override + Map getFromProtoMessage(Message message) { + List list = (List) message.getField(getFieldDescriptor(message)); + if (list.size() == 0) { + return null; + } + Map rowMap = new HashMap<>(); + list.forEach( + entryMessage -> { + Descriptors.Descriptor entryDescriptor = entryMessage.getDescriptorForType(); + FieldDescriptor keyFieldDescriptor = entryDescriptor.findFieldByName("key"); + FieldDescriptor valueFieldDescriptor = entryDescriptor.findFieldByName("value"); + rowMap.put( + key.convertFromProtoValue(entryMessage.getField(keyFieldDescriptor)), + this.value.convertFromProtoValue(entryMessage.getField(valueFieldDescriptor))); + }); + return rowMap; + } + + @Override + Map convertFromProtoValue(Object object) { + throw new RuntimeException("?"); + } + + @Override + void setOnProtoMessage(Message.Builder message, Map map) { + if (map != null) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + List messageMap = new ArrayList<>(); + map.forEach( + (k, v) -> { + DynamicMessage.Builder builder = + DynamicMessage.newBuilder(fieldDescriptor.getMessageType()); + FieldDescriptor keyFieldDescriptor = + fieldDescriptor.getMessageType().findFieldByName("key"); + builder.setField( + keyFieldDescriptor, 
this.key.convertToProtoValue(keyFieldDescriptor, k)); + FieldDescriptor valueFieldDescriptor = + fieldDescriptor.getMessageType().findFieldByName("value"); + builder.setField( + valueFieldDescriptor, value.convertToProtoValue(valueFieldDescriptor, v)); + messageMap.add(builder.build()); + }); + message.setField(fieldDescriptor, messageMap); + } + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } + + static class ArrayConvert extends Convert { + private Convert element; + + ArrayConvert(ProtoDynamicMessageSchema protoSchema, Schema.Field field) { + super(field); + Schema.FieldType collectionElementType = field.getType().getCollectionElementType(); + this.element = + protoSchema.createConverter( + Schema.Field.of( + "ELEMENT", + withMessageName(collectionElementType, getMessageName(field.getType())))); + } + + @Override + List getFromProtoMessage(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + return convertFromProtoValue(message.getField(fieldDescriptor)); + } + + @Override + List convertFromProtoValue(Object value) { + List list = (List) value; + List arrayList = new ArrayList<>(); + list.forEach( + entry -> { + arrayList.add(element.convertFromProtoValue(entry)); + }); + return arrayList; + } + + @Override + void setOnProtoMessage(Message.Builder message, List list) { + if (list != null) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + List targetList = new ArrayList<>(); + list.forEach( + (e) -> { + targetList.add(element.convertToProtoValue(fieldDescriptor, e)); + }); + message.setField(fieldDescriptor, targetList); + } + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } + + /** Enum overlay handles the conversion between a string and a ProtoBuf Enum. 
*/ + static class EnumConvert extends Convert { + EnumerationType logicalType; + + EnumConvert(Schema.Field field, Schema.LogicalType logicalType) { + super(field); + this.logicalType = (EnumerationType) logicalType; + } + + @Override + Object getFromProtoMessage(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + return convertFromProtoValue(message.getField(fieldDescriptor)); + } + + @Override + EnumerationType.Value convertFromProtoValue(Object in) { + return logicalType.valueOf(((Descriptors.EnumValueDescriptor) in).getNumber()); + } + + @Override + void setOnProtoMessage(Message.Builder message, Object value) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + message.setField(fieldDescriptor, convertToProtoValue(fieldDescriptor, value)); + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + Descriptors.EnumDescriptor enumType = fieldDescriptor.getEnumType(); + return enumType.findValueByNumber((Integer) value); + } + } + + /** Convert Proto oneOf fields into the {@link OneOfType} logical type. 
*/ + static class OneOfConvert extends Convert { + OneOfType logicalType; + Map oneOfConvert = new HashMap<>(); + + OneOfConvert( + ProtoDynamicMessageSchema protoSchema, Schema.Field field, Schema.LogicalType logicalType) { + super(field); + this.logicalType = (OneOfType) logicalType; + for (Schema.Field oneOfField : this.logicalType.getOneOfSchema().getFields()) { + int fieldNumber = getFieldNumber(oneOfField.getType()); + oneOfConvert.put( + fieldNumber, new NullableConvert(oneOfField, protoSchema.createConverter(oneOfField))); + } + } + + @Override + Object getFromProtoMessage(Message message) { + for (Map.Entry entry : this.oneOfConvert.entrySet()) { + Object value = entry.getValue().getFromProtoMessage(message); + if (value != null) { + return logicalType.createValue(entry.getKey(), value); + } + } + return null; + } + + @Override + OneOfType.Value convertFromProtoValue(Object in) { + throw new IllegalStateException("Value conversion can't be done outside a protobuf message"); + } + + @Override + void setOnProtoMessage(Message.Builder message, Row value) { + OneOfType.Value oneOf = logicalType.toInputType(value); + int caseIndex = oneOf.getCaseType().getValue(); + oneOfConvert.get(caseIndex).setOnProtoMessage(message, oneOf.getValue()); + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + throw new IllegalStateException("Value conversion can't be done outside a protobuf message"); + } + } + + /** + * This overlay handles nullable fields. If a primitive field needs to be nullable this overlay is + * wrapped around the original overlay. 
+ */ + static class NullableConvert extends Convert { + + private Convert fieldOverlay; + + NullableConvert(Schema.Field field, Convert fieldOverlay) { + super(field); + this.fieldOverlay = fieldOverlay; + } + + @Override + Object getFromProtoMessage(Message message) { + if (message.hasField(getFieldDescriptor(message))) { + return fieldOverlay.getFromProtoMessage(message); + } + return null; + } + + @Override + Object convertFromProtoValue(Object object) { + throw new IllegalStateException("Value conversion can't be done outside a protobuf message"); + } + + @Override + void setOnProtoMessage(Message.Builder message, Object value) { + if (value != null) { + fieldOverlay.setOnProtoMessage(message, value); + } + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + throw new IllegalStateException("Value conversion can't be done outside a protobuf message"); + } + } + + static class LogicalTypeConvert extends Convert { + + private Schema.LogicalType logicalType; + + LogicalTypeConvert(Schema.Field field, Schema.LogicalType logicalType) { + super(field); + this.logicalType = logicalType; + } + + @Override + Object getFromProtoMessage(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + return convertFromProtoValue(message.getField(fieldDescriptor)); + } + + @Override + Object convertFromProtoValue(Object object) { + return logicalType.toBaseType(object); + } + + @Override + void setOnProtoMessage(Message.Builder message, Object value) { + message.setField(getFieldDescriptor(message), value); + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } + + private class MessageToRowFunction implements SerializableFunction { + + private MessageToRowFunction() {} + + @Override + public Row apply(T input) { + Schema schema = context.getSchema(); + Row.Builder builder = Row.withSchema(schema); + for (Convert convert : converters) { + 
builder.addValue(convert.getFromProtoMessage((Message) input)); + } + return builder.build(); + } + } + + private class RowToMessageFunction implements SerializableFunction { + + private RowToMessageFunction() {} + + @Override + public T apply(Row input) { + DynamicMessage.Builder builder = context.invokeNewBuilder(); + Iterator values = input.getValues().iterator(); + Iterator convertIterator = converters.iterator(); + + for (int i = 0; i < input.getValues().size(); i++) { + Convert convert = convertIterator.next(); + Object value = values.next(); + convert.setOnProtoMessage(builder, value); + } + return (T) builder.build(); + } + } +} diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslator.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslator.java index 7b9930832806..bc7503fccc65 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslator.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslator.java @@ -132,6 +132,13 @@ public class ProtoSchemaTranslator { /** This METADATA tag is used to store the field number of a proto tag. */ public static final String PROTO_NUMBER_METADATA_TAG = "PROTO_NUMBER"; + public static final String PROTO_MESSAGE_NAME_METADATA_TAG = "PROTO_MESSAGE_NAME"; + + public static final String PROTO_MAP_KEY_MESSAGE_NAME_METADATA_TAG = "PROTO_MAP_KEY_MESSAGE_NAME"; + + public static final String PROTO_MAP_VALUE_MESSAGE_NAME_METADATA_TAG = + "PROTO_MAP_VALUE_MESSAGE_NAME"; + /** Attach a proto field number to a type. 
*/ public static FieldType withFieldNumber(FieldType fieldType, int index) { return fieldType.withMetadata(PROTO_NUMBER_METADATA_TAG, Long.toString(index)); @@ -142,12 +149,42 @@ public static int getFieldNumber(FieldType fieldType) { return Integer.parseInt(fieldType.getMetadataString(PROTO_NUMBER_METADATA_TAG)); } + /** Attach the name of the message to a type. */ + public static FieldType withMessageName(FieldType fieldType, String messageName) { + return fieldType.withMetadata(PROTO_MESSAGE_NAME_METADATA_TAG, messageName); + } + + /** Return the message name for a type. */ + public static String getMessageName(FieldType fieldType) { + return fieldType.getMetadataString(PROTO_MESSAGE_NAME_METADATA_TAG); + } + + /** Attach the name of the message to a map key. */ + public static FieldType withMapKeyMessageName(FieldType fieldType, String messageName) { + return fieldType.withMetadata(PROTO_MAP_KEY_MESSAGE_NAME_METADATA_TAG, messageName); + } + + /** Return the message name for a map key. */ + public static String getMapKeyMessageName(FieldType fieldType) { + return fieldType.getMetadataString(PROTO_MAP_KEY_MESSAGE_NAME_METADATA_TAG); + } + + /** Attach the name of the message to a map value. */ + public static FieldType withMapValueMessageName(FieldType fieldType, String messageName) { + return fieldType.withMetadata(PROTO_MAP_VALUE_MESSAGE_NAME_METADATA_TAG, messageName); + } + + /** Return the message name for a map value. */ + public static String getMapValueMessageName(FieldType fieldType) { + return fieldType.getMetadataString(PROTO_MAP_VALUE_MESSAGE_NAME_METADATA_TAG); + } + /** Return a Beam scheam representing a proto class. 
*/ public static Schema getSchema(Class clazz) { return getSchema(ProtobufUtil.getDescriptorForClass(clazz)); } - private static Schema getSchema(Descriptors.Descriptor descriptor) { + static Schema getSchema(Descriptors.Descriptor descriptor) { Set oneOfFields = Sets.newHashSet(); List fields = Lists.newArrayListWithCapacity(descriptor.getFields().size()); for (OneofDescriptor oneofDescriptor : descriptor.getOneofs()) { @@ -157,8 +194,7 @@ private static Schema getSchema(Descriptors.Descriptor descriptor) { oneOfFields.add(fieldDescriptor.getNumber()); // Store proto field number in metadata. FieldType fieldType = - withFieldNumber( - beamFieldTypeFromProtoField(fieldDescriptor), fieldDescriptor.getNumber()); + withMetaData(beamFieldTypeFromProtoField(fieldDescriptor), fieldDescriptor); subFields.add(Field.nullable(fieldDescriptor.getName(), fieldType)); checkArgument( enumIds.putIfAbsent(fieldDescriptor.getName(), fieldDescriptor.getNumber()) == null); @@ -171,14 +207,34 @@ private static Schema getSchema(Descriptors.Descriptor descriptor) { if (!oneOfFields.contains(fieldDescriptor.getNumber())) { // Store proto field number in metadata. 
FieldType fieldType = - withFieldNumber( - beamFieldTypeFromProtoField(fieldDescriptor), fieldDescriptor.getNumber()); + withMetaData(beamFieldTypeFromProtoField(fieldDescriptor), fieldDescriptor); fields.add(Field.of(fieldDescriptor.getName(), fieldType)); } } return Schema.builder().addFields(fields).build(); } + private static FieldType withMetaData( + FieldType inType, Descriptors.FieldDescriptor fieldDescriptor) { + FieldType fieldType = withFieldNumber(inType, fieldDescriptor.getNumber()); + if (fieldDescriptor.isMapField()) { + FieldDescriptor keyFieldDescriptor = fieldDescriptor.getMessageType().findFieldByName("key"); + FieldDescriptor valueFieldDescriptor = + fieldDescriptor.getMessageType().findFieldByName("value"); + if ((keyFieldDescriptor.getType() == FieldDescriptor.Type.MESSAGE)) { + fieldType = + withMapKeyMessageName(fieldType, keyFieldDescriptor.getMessageType().getFullName()); + } + if ((valueFieldDescriptor.getType() == FieldDescriptor.Type.MESSAGE)) { + fieldType = + withMapValueMessageName(fieldType, valueFieldDescriptor.getMessageType().getFullName()); + } + } else if (fieldDescriptor.getType() == FieldDescriptor.Type.MESSAGE) { + return withMessageName(fieldType, fieldDescriptor.getMessageType().getFullName()); + } + return fieldType; + } + private static FieldType beamFieldTypeFromProtoField( Descriptors.FieldDescriptor protoFieldDescriptor) { FieldType fieldType = null; diff --git a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchemaTest.java b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchemaTest.java new file mode 100644 index 000000000000..14bd131982cb --- /dev/null +++ b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchemaTest.java @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.protobuf; + +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withFieldNumber; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.MAP_PRIMITIVE_PROTO; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.MAP_PRIMITIVE_ROW; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.MAP_PRIMITIVE_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NESTED_PROTO; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NESTED_ROW; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NESTED_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_PROTO_BOOL; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_PROTO_INT32; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_PROTO_PRIMITIVE; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_PROTO_STRING; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_ROW_BOOL; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_ROW_INT32; +import static 
org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_ROW_PRIMITIVE; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_ROW_STRING; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.OUTER_ONEOF_PROTO; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.OUTER_ONEOF_ROW; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.OUTER_ONEOF_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.PRIMITIVE_PROTO; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.PRIMITIVE_ROW; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.PRIMITIVE_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REPEATED_PROTO; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REPEATED_ROW; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REPEATED_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.WKT_MESSAGE_PROTO; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.WKT_MESSAGE_ROW; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.WKT_MESSAGE_SCHEMA; +import static org.junit.Assert.assertEquals; + +import com.google.protobuf.Descriptors; +import com.google.protobuf.DynamicMessage; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Message; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.EnumMessage; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.MapPrimitive; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.Nested; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.OneOf; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.OuterOneOf; +import 
org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.Primitive; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.RepeatPrimitive; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.WktMessage; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Collection of tests for values on Protobuf Messages and Rows. */ +@RunWith(JUnit4.class) +public class ProtoDynamicMessageSchemaTest { + + private ProtoDynamicMessageSchema schemaFromDescriptor(Descriptors.Descriptor descriptor) { + ProtoDomain domain = ProtoDomain.buildFrom(descriptor); + return ProtoDynamicMessageSchema.forDescriptor(domain, descriptor); + } + + private DynamicMessage toDynamic(Message message) throws InvalidProtocolBufferException { + return DynamicMessage.parseFrom(message.getDescriptorForType(), message.toByteArray()); + } + + @Test + public void testPrimitiveSchema() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(Primitive.getDescriptor()); + Schema schema = schemaProvider.getSchema(); + assertEquals(PRIMITIVE_SCHEMA, schema); + } + + @Test + public void testPrimitiveProtoToRow() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(Primitive.getDescriptor()); + SerializableFunction toRow = schemaProvider.getToRowFunction(); + assertEquals(PRIMITIVE_ROW, toRow.apply(toDynamic(PRIMITIVE_PROTO))); + } + + @Test + public void testPrimitiveRowToProto() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(Primitive.getDescriptor()); + SerializableFunction fromRow = schemaProvider.getFromRowFunction(); + 
assertEquals(PRIMITIVE_PROTO.toString(), fromRow.apply(PRIMITIVE_ROW).toString()); + } + + @Test + public void testRepeatedSchema() { + ProtoDynamicMessageSchema schemaProvider = + schemaFromDescriptor(RepeatPrimitive.getDescriptor()); + Schema schema = schemaProvider.getSchema(); + assertEquals(REPEATED_SCHEMA, schema); + } + + @Test + public void testRepeatedProtoToRow() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = + schemaFromDescriptor(RepeatPrimitive.getDescriptor()); + SerializableFunction toRow = schemaProvider.getToRowFunction(); + assertEquals(REPEATED_ROW, toRow.apply(toDynamic(REPEATED_PROTO))); + } + + @Test + public void testRepeatedRowToProto() { + ProtoDynamicMessageSchema schemaProvider = + schemaFromDescriptor(RepeatPrimitive.getDescriptor()); + SerializableFunction fromRow = schemaProvider.getFromRowFunction(); + assertEquals(REPEATED_PROTO.toString(), fromRow.apply(REPEATED_ROW).toString()); + } + + // Test map type + @Test + public void testMapSchema() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(MapPrimitive.getDescriptor()); + Schema schema = schemaProvider.getSchema(); + assertEquals(MAP_PRIMITIVE_SCHEMA, schema); + } + + @Test + public void testMapProtoToRow() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(MapPrimitive.getDescriptor()); + SerializableFunction toRow = schemaProvider.getToRowFunction(); + assertEquals(MAP_PRIMITIVE_ROW, toRow.apply(toDynamic(MAP_PRIMITIVE_PROTO))); + } + + @Test + public void testMapRowToProto() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(MapPrimitive.getDescriptor()); + SerializableFunction fromRow = schemaProvider.getFromRowFunction(); + assertEquals(MAP_PRIMITIVE_PROTO.toString(), fromRow.apply(MAP_PRIMITIVE_ROW).toString()); + } + + @Test + public void testNestedSchema() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(Nested.getDescriptor()); 
+ Schema schema = schemaProvider.getSchema(); + assertEquals(NESTED_SCHEMA, schema); + } + + @Test + public void testNestedProtoToRow() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(Nested.getDescriptor()); + SerializableFunction toRow = schemaProvider.getToRowFunction(); + assertEquals(NESTED_ROW, toRow.apply(toDynamic(NESTED_PROTO))); + } + + @Test + public void testNestedRowToProto() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(Nested.getDescriptor()); + SerializableFunction fromRow = schemaProvider.getFromRowFunction(); + // equality doesn't work between dynamic messages and other, + // so we compare string representation + assertEquals(NESTED_PROTO.toString(), fromRow.apply(NESTED_ROW).toString()); + } + + @Test + public void testOneOfSchema() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(OneOf.getDescriptor()); + Schema schema = schemaProvider.getSchema(); + assertEquals(ONEOF_SCHEMA, schema); + } + + @Test + public void testOneOfProtoToRow() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(OneOf.getDescriptor()); + SerializableFunction toRow = schemaProvider.getToRowFunction(); + // equality doesn't work between dynamic messages and other, + // so we compare string representation + assertEquals(ONEOF_ROW_INT32.toString(), toRow.apply(toDynamic(ONEOF_PROTO_INT32)).toString()); + assertEquals(ONEOF_ROW_BOOL.toString(), toRow.apply(toDynamic(ONEOF_PROTO_BOOL)).toString()); + assertEquals( + ONEOF_ROW_STRING.toString(), toRow.apply(toDynamic(ONEOF_PROTO_STRING)).toString()); + assertEquals( + ONEOF_ROW_PRIMITIVE.toString(), toRow.apply(toDynamic(ONEOF_PROTO_PRIMITIVE)).toString()); + } + + @Test + public void testOneOfRowToProto() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(OneOf.getDescriptor()); + SerializableFunction fromRow = 
schemaProvider.getFromRowFunction(); + assertEquals(ONEOF_PROTO_INT32.toString(), fromRow.apply(ONEOF_ROW_INT32).toString()); + assertEquals(ONEOF_PROTO_BOOL.toString(), fromRow.apply(ONEOF_ROW_BOOL).toString()); + assertEquals(ONEOF_PROTO_STRING.toString(), fromRow.apply(ONEOF_ROW_STRING).toString()); + assertEquals(ONEOF_PROTO_PRIMITIVE.toString(), fromRow.apply(ONEOF_ROW_PRIMITIVE).toString()); + } + + @Test + public void testOuterOneOfSchema() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(OuterOneOf.getDescriptor()); + Schema schema = schemaProvider.getSchema(); + assertEquals(OUTER_ONEOF_SCHEMA, schema); + } + + @Test + public void testOuterOneOfProtoToRow() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(OuterOneOf.getDescriptor()); + SerializableFunction toRow = schemaProvider.getToRowFunction(); + // equality doesn't work between dynamic messages and other, + // so we compare string representation + assertEquals(OUTER_ONEOF_ROW.toString(), toRow.apply(toDynamic(OUTER_ONEOF_PROTO)).toString()); + } + + @Test + public void testOuterOneOfRowToProto() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(OuterOneOf.getDescriptor()); + SerializableFunction fromRow = schemaProvider.getFromRowFunction(); + assertEquals(OUTER_ONEOF_PROTO.toString(), fromRow.apply(OUTER_ONEOF_ROW).toString()); + } + + private static final EnumerationType ENUM_TYPE = + EnumerationType.create(ImmutableMap.of("ZERO", 0, "TWO", 2, "THREE", 3)); + private static final Schema ENUM_SCHEMA = + Schema.builder() + .addField( + "enum", + withFieldNumber(Schema.FieldType.logicalType(ENUM_TYPE).withNullable(false), 1)) + .build(); + private static final Row ENUM_ROW = + Row.withSchema(ENUM_SCHEMA).addValues(ENUM_TYPE.valueOf("TWO")).build(); + private static final EnumMessage ENUM_PROTO = + EnumMessage.newBuilder().setEnum(EnumMessage.Enum.TWO).build(); + + @Test + public void testEnumSchema() { + 
ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(EnumMessage.getDescriptor()); + Schema schema = schemaProvider.getSchema(); + assertEquals(ENUM_SCHEMA, schema); + } + + @Test + public void testEnumProtoToRow() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(EnumMessage.getDescriptor()); + SerializableFunction toRow = schemaProvider.getToRowFunction(); + assertEquals(ENUM_ROW, toRow.apply(toDynamic(ENUM_PROTO))); + } + + @Test + public void testEnumRowToProto() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(EnumMessage.getDescriptor()); + SerializableFunction fromRow = schemaProvider.getFromRowFunction(); + assertEquals(ENUM_PROTO.toString(), fromRow.apply(ENUM_ROW).toString()); + } + + @Test + public void testWktMessageSchema() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(WktMessage.getDescriptor()); + Schema schema = schemaProvider.getSchema(); + assertEquals(WKT_MESSAGE_SCHEMA, schema); + } + + @Test + public void testWktProtoToRow() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(WktMessage.getDescriptor()); + SerializableFunction toRow = schemaProvider.getToRowFunction(); + assertEquals(WKT_MESSAGE_ROW, toRow.apply(toDynamic(WKT_MESSAGE_PROTO))); + } + + @Test + public void testWktRowToProto() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(WktMessage.getDescriptor()); + SerializableFunction fromRow = schemaProvider.getFromRowFunction(); + assertEquals(WKT_MESSAGE_PROTO.toString(), fromRow.apply(WKT_MESSAGE_ROW).toString()); + } +} diff --git a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/TestProtoSchemas.java b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/TestProtoSchemas.java index 10ed0ece2706..862d637c8506 100644 --- 
a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/TestProtoSchemas.java +++ b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/TestProtoSchemas.java @@ -19,6 +19,8 @@ import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.getFieldNumber; import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withFieldNumber; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withMapValueMessageName; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withMessageName; import com.google.protobuf.BoolValue; import com.google.protobuf.ByteString; @@ -276,12 +278,21 @@ class TestProtoSchemas { static final Schema NESTED_SCHEMA = Schema.builder() .addField( - "nested", withFieldNumber(FieldType.row(PRIMITIVE_SCHEMA).withNullable(true), 1)) + "nested", + withMessageName( + withFieldNumber(FieldType.row(PRIMITIVE_SCHEMA).withNullable(true), 1), + "proto3_schema_messages.Primitive")) .addField( - "nested_list", withFieldNumber(FieldType.array(FieldType.row(PRIMITIVE_SCHEMA)), 2)) + "nested_list", + withMessageName( + withFieldNumber(FieldType.array(FieldType.row(PRIMITIVE_SCHEMA)), 2), + "proto3_schema_messages.Primitive")) .addField( "nested_map", - withFieldNumber(FieldType.map(FieldType.STRING, FieldType.row(PRIMITIVE_SCHEMA)), 3)) + withMapValueMessageName( + withFieldNumber( + FieldType.map(FieldType.STRING, FieldType.row(PRIMITIVE_SCHEMA)), 3), + "proto3_schema_messages.Primitive")) .build(); // A sample instance of the row. 
@@ -306,7 +317,11 @@ class TestProtoSchemas { Field.of("oneof_int32", withFieldNumber(FieldType.INT32, 2)), Field.of("oneof_bool", withFieldNumber(FieldType.BOOLEAN, 3)), Field.of("oneof_string", withFieldNumber(FieldType.STRING, 4)), - Field.of("oneof_primitive", withFieldNumber(FieldType.row(PRIMITIVE_SCHEMA), 5))); + Field.of( + "oneof_primitive", + withMessageName( + withFieldNumber(FieldType.row(PRIMITIVE_SCHEMA), 5), + "proto3_schema_messages.Primitive"))); private static final Map ONE_OF_ENUM_MAP = ONEOF_FIELDS.stream() .collect(Collectors.toMap(Field::getName, f -> getFieldNumber(f.getType()))); @@ -349,7 +364,10 @@ class TestProtoSchemas { // The schema for the OuterOneOf proto. private static final List OUTER_ONEOF_FIELDS = ImmutableList.of( - Field.of("oneof_oneof", withFieldNumber(FieldType.row(ONEOF_SCHEMA), 1)), + Field.of( + "oneof_oneof", + withMessageName( + withFieldNumber(FieldType.row(ONEOF_SCHEMA), 1), "proto3_schema_messages.OneOf")), Field.of("oneof_int32", withFieldNumber(FieldType.INT32, 2))); private static final Map OUTER_ONE_OF_ENUM_MAP = OUTER_ONEOF_FIELDS.stream() @@ -371,19 +389,47 @@ class TestProtoSchemas { static final Schema WKT_MESSAGE_SCHEMA = Schema.builder() - .addNullableField("double", withFieldNumber(FieldType.DOUBLE, 1)) - .addNullableField("float", withFieldNumber(FieldType.FLOAT, 2)) - .addNullableField("int32", withFieldNumber(FieldType.INT32, 3)) - .addNullableField("int64", withFieldNumber(FieldType.INT64, 4)) - .addNullableField("uint32", withFieldNumber(FieldType.logicalType(new UInt32()), 5)) - .addNullableField("uint64", withFieldNumber(FieldType.logicalType(new UInt64()), 6)) - .addNullableField("bool", withFieldNumber(FieldType.BOOLEAN, 13)) - .addNullableField("string", withFieldNumber(FieldType.STRING, 14)) - .addNullableField("bytes", withFieldNumber(FieldType.BYTES, 15)) .addNullableField( - "timestamp", withFieldNumber(FieldType.logicalType(new NanosInstant()), 16)) + "double", + 
withMessageName(withFieldNumber(FieldType.DOUBLE, 1), "google.protobuf.DoubleValue")) + .addNullableField( + "float", + withMessageName(withFieldNumber(FieldType.FLOAT, 2), "google.protobuf.FloatValue")) + .addNullableField( + "int32", + withMessageName(withFieldNumber(FieldType.INT32, 3), "google.protobuf.Int32Value")) + .addNullableField( + "int64", + withMessageName(withFieldNumber(FieldType.INT64, 4), "google.protobuf.Int64Value")) + .addNullableField( + "uint32", + withMessageName( + withFieldNumber(FieldType.logicalType(new UInt32()), 5), + "google.protobuf.UInt32Value")) + .addNullableField( + "uint64", + withMessageName( + withFieldNumber(FieldType.logicalType(new UInt64()), 6), + "google.protobuf.UInt64Value")) + .addNullableField( + "bool", + withMessageName(withFieldNumber(FieldType.BOOLEAN, 13), "google.protobuf.BoolValue")) + .addNullableField( + "string", + withMessageName(withFieldNumber(FieldType.STRING, 14), "google.protobuf.StringValue")) + .addNullableField( + "bytes", + withMessageName(withFieldNumber(FieldType.BYTES, 15), "google.protobuf.BytesValue")) + .addNullableField( + "timestamp", + withMessageName( + withFieldNumber(FieldType.logicalType(new NanosInstant()), 16), + "google.protobuf.Timestamp")) .addNullableField( - "duration", withFieldNumber(FieldType.logicalType(new NanosDuration()), 17)) + "duration", + withMessageName( + withFieldNumber(FieldType.logicalType(new NanosDuration()), 17), + "google.protobuf.Duration")) .build(); // A sample instance of the row. static final Instant JAVA_NOW = Instant.now(); @@ -426,7 +472,9 @@ class TestProtoSchemas { Schema.builder() .addField( "nested", - withFieldNumber(FieldType.row(OPTIONAL_PRIMITIVE_SCHEMA), 1).withNullable(true)) + withMessageName( + withFieldNumber(FieldType.row(OPTIONAL_PRIMITIVE_SCHEMA), 1).withNullable(true), + "proto2_schema_messages.OptionalPrimitive")) .build(); // A sample instance of the proto. 
@@ -438,7 +486,9 @@ class TestProtoSchemas { Schema.builder() .addField( "nested", - withFieldNumber(FieldType.row(REQUIRED_PRIMITIVE_SCHEMA), 1).withNullable(false)) + withMessageName( + withFieldNumber(FieldType.row(REQUIRED_PRIMITIVE_SCHEMA), 1).withNullable(false), + "proto2_schema_messages.RequiredPrimitive")) .build(); // A sample instance of the proto.