diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java index 46b219c027f1..af2747519d8f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java @@ -118,8 +118,13 @@ public FieldType getBaseType() { } /** Create a {@link Value} specifying which field to set and the value to set. */ - public Value createValue(String caseType, T value) { - return createValue(getCaseEnumType().valueOf(caseType), value); + public Value createValue(String caseValue, T value) { + return createValue(getCaseEnumType().valueOf(caseValue), value); + } + + /** Create a {@link Value} specifying which field to set and the value to set. */ + public Value createValue(int caseValue, T value) { + return createValue(getCaseEnumType().valueOf(caseValue), value); } /** Create a {@link Value} specifying which field to set and the value to set. 
*/ diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDomain.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDomain.java index e9a5d48ed35b..c13cf882c547 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDomain.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDomain.java @@ -163,15 +163,15 @@ private void indexDescriptorByName() { .values() .forEach( fileDescriptor -> { - fileDescriptor - .getMessageTypes() - .forEach( - descriptor -> { - descriptorMap.put(descriptor.getFullName(), descriptor); - }); + fileDescriptor.getMessageTypes().forEach(this::indexDescriptor); }); } + private void indexDescriptor(Descriptors.Descriptor descriptor) { + descriptorMap.put(descriptor.getFullName(), descriptor); + descriptor.getNestedTypes().forEach(this::indexDescriptor); + } + private void indexOptionsByNumber(Collection fileDescriptors) { fieldOptionMap = new HashMap<>(); fileOptionMap = new HashMap<>(); diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchema.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchema.java new file mode 100644 index 000000000000..93bf00256698 --- /dev/null +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchema.java @@ -0,0 +1,844 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.protobuf; + +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.getFieldNumber; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.getMapKeyMessageName; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.getMapValueMessageName; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.getMessageName; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withFieldNumber; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withMessageName; + +import com.google.protobuf.ByteString; +import com.google.protobuf.Descriptors; +import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.DynamicMessage; +import com.google.protobuf.Message; +import java.io.Serializable; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType; +import org.apache.beam.sdk.schemas.logicaltypes.NanosDuration; +import org.apache.beam.sdk.schemas.logicaltypes.NanosInstant; +import org.apache.beam.sdk.schemas.logicaltypes.OneOfType; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.Row; + +@Experimental(Experimental.Kind.SCHEMAS) +public class 
ProtoDynamicMessageSchema implements Serializable { + public static final long serialVersionUID = 1L; + + /** + * Context of the schema, the context can be generated from a source schema or descriptors. The + * ability of converting back from Row to proto depends on the type of context. + */ + private final Context context; + + /** The toRow function to convert the Message to a Row. */ + private transient SerializableFunction toRowFunction; + + /** The fromRow function to convert the Row to a Message. */ + private transient SerializableFunction fromRowFunction; + + /** List of field converters for each field in the row. */ + private transient List converters; + + private ProtoDynamicMessageSchema(String messageName, ProtoDomain domain) { + this.context = new DescriptorContext(messageName, domain); + readResolve(); + } + + private ProtoDynamicMessageSchema(Context context) { + this.context = context; + readResolve(); + } + + /** + * Create a new ProtoDynamicMessageSchema from a {@link ProtoDomain} and for a message. The + * message needs to be in the domain and needs to be the fully qualified name. + */ + public static ProtoDynamicMessageSchema forDescriptor(ProtoDomain domain, String messageName) { + return new ProtoDynamicMessageSchema(messageName, domain); + } + + /** + * Create a new ProtoDynamicMessageSchema from a {@link ProtoDomain} and for a descriptor. The + * descriptor is only used for its name; that name will be used for a search in the domain. 
+ */ + public static ProtoDynamicMessageSchema forDescriptor( + ProtoDomain domain, Descriptors.Descriptor descriptor) { + return new ProtoDynamicMessageSchema<>(descriptor.getFullName(), domain); + } + + static ProtoDynamicMessageSchema forContext(Context context, Schema.Field field) { + return new ProtoDynamicMessageSchema<>(context.getSubContext(field)); + } + + static ProtoDynamicMessageSchema forSchema(Schema schema) { + return new ProtoDynamicMessageSchema<>(new Context(schema, Message.class)); + } + + /** Initialize the transient fields after deserialization or construction. */ + private Object readResolve() { + converters = createConverters(context.getSchema()); + toRowFunction = new MessageToRowFunction(); + fromRowFunction = new RowToMessageFunction(); + return this; + } + + Convert createConverter(Schema.Field field) { + Schema.FieldType fieldType = field.getType(); + String messageName = getMessageName(fieldType); + if (messageName != null && messageName.length() > 0) { + Schema.Field valueField = + Schema.Field.of("value", withFieldNumber(Schema.FieldType.BOOLEAN, 1)); + switch (messageName) { + case "google.protobuf.StringValue": + case "google.protobuf.DoubleValue": + case "google.protobuf.FloatValue": + case "google.protobuf.BoolValue": + case "google.protobuf.Int64Value": + case "google.protobuf.Int32Value": + case "google.protobuf.UInt64Value": + case "google.protobuf.UInt32Value": + return new WrapperConvert(field, new PrimitiveConvert(valueField)); + case "google.protobuf.BytesValue": + return new WrapperConvert(field, new BytesConvert(valueField)); + case "google.protobuf.Timestamp": + case "google.protobuf.Duration": + // handled by logical type case + break; + } + } + switch (fieldType.getTypeName()) { + case BYTE: + case INT16: + case INT32: + case INT64: + case FLOAT: + case DOUBLE: + case STRING: + case BOOLEAN: + return new PrimitiveConvert(field); + case BYTES: + return new BytesConvert(field); + case ARRAY: + case ITERABLE: + return new 
ArrayConvert(this, field); + case MAP: + return new MapConvert(this, field); + case LOGICAL_TYPE: + String identifier = field.getType().getLogicalType().getIdentifier(); + switch (identifier) { + case ProtoSchemaLogicalTypes.Fixed32.IDENTIFIER: + case ProtoSchemaLogicalTypes.Fixed64.IDENTIFIER: + case ProtoSchemaLogicalTypes.SFixed32.IDENTIFIER: + case ProtoSchemaLogicalTypes.SFixed64.IDENTIFIER: + case ProtoSchemaLogicalTypes.SInt32.IDENTIFIER: + case ProtoSchemaLogicalTypes.SInt64.IDENTIFIER: + case ProtoSchemaLogicalTypes.UInt32.IDENTIFIER: + case ProtoSchemaLogicalTypes.UInt64.IDENTIFIER: + return new LogicalTypeConvert(field, fieldType.getLogicalType()); + case NanosInstant.IDENTIFIER: + return new TimestampConvert(field); + case NanosDuration.IDENTIFIER: + return new DurationConvert(field); + case EnumerationType.IDENTIFIER: + return new EnumConvert(field, fieldType.getLogicalType()); + case OneOfType.IDENTIFIER: + return new OneOfConvert(this, field, fieldType.getLogicalType()); + default: + throw new IllegalStateException("Unexpected logical type : " + identifier); + } + case ROW: + return new MessageConvert(this, field); + default: + throw new IllegalStateException("Unexpected value: " + fieldType); + } + } + + private List createConverters(Schema schema) { + List fieldOverlays = new ArrayList<>(); + for (Schema.Field field : schema.getFields()) { + fieldOverlays.add(createConverter(field)); + } + return fieldOverlays; + } + + public Schema getSchema() { + return context.getSchema(); + } + + public SerializableFunction getToRowFunction() { + return toRowFunction; + } + + public SerializableFunction getFromRowFunction() { + return fromRowFunction; + } + + /** + * Context that only has enough information to convert a proto message to a Row. This can be used + * for arbitrary conversions, like decoding messages in proto options. 
+ */ + static class Context implements Serializable { + private final Schema schema; + + /** + * Base class for the protobuf message. Normally this is DynamicMessage, but as this schema + * class is also used to decode protobuf options this can be normal Message instances. + */ + private Class baseClass; + + Context(Schema schema, Class baseClass) { + this.schema = schema; + this.baseClass = baseClass; + } + + public Schema getSchema() { + return schema; + } + + public Class getBaseClass() { + return baseClass; + } + + public DynamicMessage.Builder invokeNewBuilder() { + throw new IllegalStateException("Should not be calling invokeNewBuilder"); + } + + public Context getSubContext(Schema.Field field) { + return new Context(field.getType().getRowSchema(), Message.class); + } + } + + /** + * Context that contains the full {@link ProtoDomain} and a reference to the message name. The full + * domain is needed for converting Rows back to the original proto messages. + */ + static class DescriptorContext extends Context { + private final String messageName; + private final ProtoDomain domain; + private transient Descriptors.Descriptor descriptor; + + DescriptorContext(String messageName, ProtoDomain domain) { + super( + ProtoSchemaTranslator.getSchema(domain.getDescriptor(messageName)), DynamicMessage.class); + this.messageName = messageName; + this.domain = domain; + } + + @Override + public DynamicMessage.Builder invokeNewBuilder() { + if (descriptor == null) { + descriptor = domain.getDescriptor(messageName); + } + return DynamicMessage.newBuilder(descriptor); + } + + @Override + public Context getSubContext(Schema.Field field) { + String messageName = getMessageName(field.getType()); + return new DescriptorContext(messageName, domain); + } + } + + /** + * Base converter class for converting from proto values to row values. 
The converter mainly works + * on fields in proto messages but also has methods to convert individual elements (for example, for + * elements in Lists or Maps). + */ + abstract static class Convert { + private int number; + + Convert(Schema.Field field) { + try { + this.number = getFieldNumber(field.getType()); + } catch (NumberFormatException e) { + this.number = -1; + } + } + + FieldDescriptor getFieldDescriptor(Message message) { + return message.getDescriptorForType().findFieldByNumber(number); + } + + FieldDescriptor getFieldDescriptor(Message.Builder message) { + return message.getDescriptorForType().findFieldByNumber(number); + } + + /** Get a proto field and convert it into a row value. */ + abstract Object getFromProtoMessage(Message message); + + /** Convert a proto value into a row value. */ + abstract ValueT convertFromProtoValue(Object object); + + /** Convert a row value and set it on a proto message. */ + abstract void setOnProtoMessage(Message.Builder object, InT value); + + /** Convert a row value into a proto value. */ + abstract Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value); + } + + /** Converter for primitive proto values. */ + static class PrimitiveConvert extends Convert { + PrimitiveConvert(Schema.Field field) { + super(field); + } + + @Override + Object getFromProtoMessage(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + return convertFromProtoValue(message.getField(fieldDescriptor)); + } + + @Override + Object convertFromProtoValue(Object object) { + return object; + } + + @Override + void setOnProtoMessage(Message.Builder message, Object value) { + message.setField(getFieldDescriptor(message), value); + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } + + /** + * Converter for Bytes. Protobuf Bytes are natively represented as ByteStrings, which require + * special handling for byte[] of size 0. 
+ */ + static class BytesConvert extends PrimitiveConvert { + BytesConvert(Schema.Field field) { + super(field); + } + + @Override + Object convertFromProtoValue(Object object) { + // return object; + return ((ByteString) object).toByteArray(); + } + + @Override + void setOnProtoMessage(Message.Builder message, Object value) { + if (value != null && ((byte[]) value).length > 0) { + // Protobuf messages BYTES doesn't like empty bytes?! + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + message.setField(fieldDescriptor, convertToProtoValue(fieldDescriptor, value)); + } + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + if (value != null) { + return ByteString.copyFrom((byte[]) value); + } + return null; + } + } + + /** + * Specific converter for Proto Wrapper values as they are translated into nullable row values. + */ + static class WrapperConvert extends Convert { + private Convert valueConvert; + + WrapperConvert(Schema.Field field, Convert valueConvert) { + super(field); + this.valueConvert = valueConvert; + } + + @Override + Object getFromProtoMessage(Message message) { + if (message.hasField(getFieldDescriptor(message))) { + Message wrapper = (Message) message.getField(getFieldDescriptor(message)); + return valueConvert.getFromProtoMessage(wrapper); + } + return null; + } + + @Override + Object convertFromProtoValue(Object object) { + return object; + } + + @Override + void setOnProtoMessage(Message.Builder message, Object value) { + if (value != null) { + DynamicMessage.Builder builder = + DynamicMessage.newBuilder(getFieldDescriptor(message).getMessageType()); + valueConvert.setOnProtoMessage(builder, value); + message.setField(getFieldDescriptor(message), builder.build()); + } + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } + + static class TimestampConvert extends Convert { + + TimestampConvert(Schema.Field field) { + 
super(field); + } + + @Override + Object getFromProtoMessage(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + if (message.hasField(fieldDescriptor)) { + Message wrapper = (Message) message.getField(fieldDescriptor); + return convertFromProtoValue(wrapper); + } + return null; + } + + @Override + Object convertFromProtoValue(Object object) { + Message timestamp = (Message) object; + Descriptors.Descriptor timestampDescriptor = timestamp.getDescriptorForType(); + FieldDescriptor secondField = timestampDescriptor.findFieldByNumber(1); + FieldDescriptor nanoField = timestampDescriptor.findFieldByNumber(2); + long second = (long) timestamp.getField(secondField); + int nano = (int) timestamp.getField(nanoField); + return Instant.ofEpochSecond(second, nano); + } + + @Override + void setOnProtoMessage(Message.Builder message, Object value) { + if (value != null) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + message.setField(fieldDescriptor, convertToProtoValue(fieldDescriptor, value)); + } + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + Row row = (Row) value; + return com.google.protobuf.Timestamp.newBuilder() + .setSeconds(row.getInt64(0)) + .setNanos(row.getInt32(1)) + .build(); + } + } + + static class DurationConvert extends Convert { + + DurationConvert(Schema.Field field) { + super(field); + } + + @Override + Object getFromProtoMessage(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + if (message.hasField(fieldDescriptor)) { + Message wrapper = (Message) message.getField(fieldDescriptor); + return convertFromProtoValue(wrapper); + } + return null; + } + + @Override + Duration convertFromProtoValue(Object object) { + Message timestamp = (Message) object; + Descriptors.Descriptor timestampDescriptor = timestamp.getDescriptorForType(); + FieldDescriptor secondField = timestampDescriptor.findFieldByNumber(1); + FieldDescriptor 
nanoField = timestampDescriptor.findFieldByNumber(2); + long second = (long) timestamp.getField(secondField); + int nano = (int) timestamp.getField(nanoField); + return Duration.ofSeconds(second, nano); + } + + @Override + void setOnProtoMessage(Message.Builder message, Object value) { + if (value != null) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + message.setField(fieldDescriptor, convertToProtoValue(fieldDescriptor, value)); + } + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + Row row = (Row) value; + return com.google.protobuf.Duration.newBuilder() + .setSeconds(row.getInt64(0)) + .setNanos(row.getInt32(1)) + .build(); + } + } + + static class MessageConvert extends Convert { + private final SerializableFunction fromRowFunction; + private final SerializableFunction toRowFunction; + + MessageConvert(ProtoDynamicMessageSchema rootProtoSchema, Schema.Field field) { + super(field); + ProtoDynamicMessageSchema protoSchema = + ProtoDynamicMessageSchema.forContext(rootProtoSchema.context, field); + toRowFunction = protoSchema.getToRowFunction(); + fromRowFunction = protoSchema.getFromRowFunction(); + } + + @Override + Object getFromProtoMessage(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + if (message.hasField(fieldDescriptor)) { + return convertFromProtoValue(message.getField(fieldDescriptor)); + } + return null; + } + + @Override + Object convertFromProtoValue(Object object) { + return toRowFunction.apply(object); + } + + @Override + void setOnProtoMessage(Message.Builder message, Object value) { + if (value != null) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + message.setField(fieldDescriptor, convertToProtoValue(fieldDescriptor, value)); + } + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + return fromRowFunction.apply(value); + } + } + + /** + * Proto has a well defined way of 
storing maps, by having a Message with two fields, named "key" + * and "value" in a repeatable field. This overlay translates between Row.map and the Protobuf + * map. + */ + static class MapConvert extends Convert { + private Convert key; + private Convert value; + + MapConvert(ProtoDynamicMessageSchema protoSchema, Schema.Field field) { + super(field); + Schema.FieldType fieldType = field.getType(); + key = + protoSchema.createConverter( + Schema.Field.of( + "KEY", + withMessageName(fieldType.getMapKeyType(), getMapKeyMessageName(fieldType)))); + value = + protoSchema.createConverter( + Schema.Field.of( + "VALUE", + withMessageName(fieldType.getMapValueType(), getMapValueMessageName(fieldType)))); + } + + @Override + Map getFromProtoMessage(Message message) { + List list = (List) message.getField(getFieldDescriptor(message)); + if (list.size() == 0) { + return null; + } + Map rowMap = new HashMap<>(); + list.forEach( + entryMessage -> { + Descriptors.Descriptor entryDescriptor = entryMessage.getDescriptorForType(); + FieldDescriptor keyFieldDescriptor = entryDescriptor.findFieldByName("key"); + FieldDescriptor valueFieldDescriptor = entryDescriptor.findFieldByName("value"); + rowMap.put( + key.convertFromProtoValue(entryMessage.getField(keyFieldDescriptor)), + this.value.convertFromProtoValue(entryMessage.getField(valueFieldDescriptor))); + }); + return rowMap; + } + + @Override + Map convertFromProtoValue(Object object) { + throw new RuntimeException("?"); + } + + @Override + void setOnProtoMessage(Message.Builder message, Map map) { + if (map != null) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + List messageMap = new ArrayList<>(); + map.forEach( + (k, v) -> { + DynamicMessage.Builder builder = + DynamicMessage.newBuilder(fieldDescriptor.getMessageType()); + FieldDescriptor keyFieldDescriptor = + fieldDescriptor.getMessageType().findFieldByName("key"); + builder.setField( + keyFieldDescriptor, 
this.key.convertToProtoValue(keyFieldDescriptor, k)); + FieldDescriptor valueFieldDescriptor = + fieldDescriptor.getMessageType().findFieldByName("value"); + builder.setField( + valueFieldDescriptor, value.convertToProtoValue(valueFieldDescriptor, v)); + messageMap.add(builder.build()); + }); + message.setField(fieldDescriptor, messageMap); + } + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } + + static class ArrayConvert extends Convert { + private Convert element; + + ArrayConvert(ProtoDynamicMessageSchema protoSchema, Schema.Field field) { + super(field); + Schema.FieldType collectionElementType = field.getType().getCollectionElementType(); + this.element = + protoSchema.createConverter( + Schema.Field.of( + "ELEMENT", + withMessageName(collectionElementType, getMessageName(field.getType())))); + } + + @Override + List getFromProtoMessage(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + return convertFromProtoValue(message.getField(fieldDescriptor)); + } + + @Override + List convertFromProtoValue(Object value) { + List list = (List) value; + List arrayList = new ArrayList<>(); + list.forEach( + entry -> { + arrayList.add(element.convertFromProtoValue(entry)); + }); + return arrayList; + } + + @Override + void setOnProtoMessage(Message.Builder message, List list) { + if (list != null) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + List targetList = new ArrayList<>(); + list.forEach( + (e) -> { + targetList.add(element.convertToProtoValue(fieldDescriptor, e)); + }); + message.setField(fieldDescriptor, targetList); + } + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } + + /** Enum overlay handles the conversion between a string and a ProtoBuf Enum. 
*/ + static class EnumConvert extends Convert { + EnumerationType logicalType; + + EnumConvert(Schema.Field field, Schema.LogicalType logicalType) { + super(field); + this.logicalType = (EnumerationType) logicalType; + } + + @Override + Object getFromProtoMessage(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + return convertFromProtoValue(message.getField(fieldDescriptor)); + } + + @Override + EnumerationType.Value convertFromProtoValue(Object in) { + return logicalType.valueOf(((Descriptors.EnumValueDescriptor) in).getNumber()); + } + + @Override + void setOnProtoMessage(Message.Builder message, Object value) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + message.setField(fieldDescriptor, convertToProtoValue(fieldDescriptor, value)); + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + Descriptors.EnumDescriptor enumType = fieldDescriptor.getEnumType(); + return enumType.findValueByNumber((Integer) value); + } + } + + /** Convert Proto oneOf fields into the {@link OneOfType} logical type. 
*/ + static class OneOfConvert extends Convert { + OneOfType logicalType; + Map oneOfConvert = new HashMap<>(); + + OneOfConvert( + ProtoDynamicMessageSchema protoSchema, Schema.Field field, Schema.LogicalType logicalType) { + super(field); + this.logicalType = (OneOfType) logicalType; + for (Schema.Field oneOfField : this.logicalType.getOneOfSchema().getFields()) { + int fieldNumber = getFieldNumber(oneOfField.getType()); + oneOfConvert.put( + fieldNumber, new NullableConvert(oneOfField, protoSchema.createConverter(oneOfField))); + } + } + + @Override + Object getFromProtoMessage(Message message) { + for (Map.Entry entry : this.oneOfConvert.entrySet()) { + Object value = entry.getValue().getFromProtoMessage(message); + if (value != null) { + return logicalType.createValue(entry.getKey(), value); + } + } + return null; + } + + @Override + OneOfType.Value convertFromProtoValue(Object in) { + throw new IllegalStateException("Value conversion can't be done outside a protobuf message"); + } + + @Override + void setOnProtoMessage(Message.Builder message, Row value) { + OneOfType.Value oneOf = logicalType.toInputType(value); + int caseIndex = oneOf.getCaseType().getValue(); + oneOfConvert.get(caseIndex).setOnProtoMessage(message, oneOf.getValue()); + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + throw new IllegalStateException("Value conversion can't be done outside a protobuf message"); + } + } + + /** + * This overlay handles nullable fields. If a primitive field needs to be nullable this overlay is + * wrapped around the original overlay. 
+ */ + static class NullableConvert extends Convert { + + private Convert fieldOverlay; + + NullableConvert(Schema.Field field, Convert fieldOverlay) { + super(field); + this.fieldOverlay = fieldOverlay; + } + + @Override + Object getFromProtoMessage(Message message) { + if (message.hasField(getFieldDescriptor(message))) { + return fieldOverlay.getFromProtoMessage(message); + } + return null; + } + + @Override + Object convertFromProtoValue(Object object) { + throw new IllegalStateException("Value conversion can't be done outside a protobuf message"); + } + + @Override + void setOnProtoMessage(Message.Builder message, Object value) { + if (value != null) { + fieldOverlay.setOnProtoMessage(message, value); + } + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + throw new IllegalStateException("Value conversion can't be done outside a protobuf message"); + } + } + + static class LogicalTypeConvert extends Convert { + + private Schema.LogicalType logicalType; + + LogicalTypeConvert(Schema.Field field, Schema.LogicalType logicalType) { + super(field); + this.logicalType = logicalType; + } + + @Override + Object getFromProtoMessage(Message message) { + FieldDescriptor fieldDescriptor = getFieldDescriptor(message); + return convertFromProtoValue(message.getField(fieldDescriptor)); + } + + @Override + Object convertFromProtoValue(Object object) { + return logicalType.toBaseType(object); + } + + @Override + void setOnProtoMessage(Message.Builder message, Object value) { + message.setField(getFieldDescriptor(message), value); + } + + @Override + Object convertToProtoValue(FieldDescriptor fieldDescriptor, Object value) { + return value; + } + } + + private class MessageToRowFunction implements SerializableFunction { + + private MessageToRowFunction() {} + + @Override + public Row apply(T input) { + Schema schema = context.getSchema(); + Row.Builder builder = Row.withSchema(schema); + for (Convert convert : converters) { + 
builder.addValue(convert.getFromProtoMessage((Message) input)); + } + return builder.build(); + } + } + + private class RowToMessageFunction implements SerializableFunction { + + private RowToMessageFunction() {} + + @Override + public T apply(Row input) { + DynamicMessage.Builder builder = context.invokeNewBuilder(); + Iterator values = input.getValues().iterator(); + Iterator convertIterator = converters.iterator(); + + for (int i = 0; i < input.getValues().size(); i++) { + Convert convert = convertIterator.next(); + Object value = values.next(); + convert.setOnProtoMessage(builder, value); + } + return (T) builder.build(); + } + } +} diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslator.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslator.java index 7b9930832806..bc7503fccc65 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslator.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslator.java @@ -132,6 +132,13 @@ public class ProtoSchemaTranslator { /** This METADATA tag is used to store the field number of a proto tag. */ public static final String PROTO_NUMBER_METADATA_TAG = "PROTO_NUMBER"; + public static final String PROTO_MESSAGE_NAME_METADATA_TAG = "PROTO_MESSAGE_NAME"; + + public static final String PROTO_MAP_KEY_MESSAGE_NAME_METADATA_TAG = "PROTO_MAP_KEY_MESSAGE_NAME"; + + public static final String PROTO_MAP_VALUE_MESSAGE_NAME_METADATA_TAG = + "PROTO_MAP_VALUE_MESSAGE_NAME"; + /** Attach a proto field number to a type. 
*/ public static FieldType withFieldNumber(FieldType fieldType, int index) { return fieldType.withMetadata(PROTO_NUMBER_METADATA_TAG, Long.toString(index)); @@ -142,12 +149,42 @@ public static int getFieldNumber(FieldType fieldType) { return Integer.parseInt(fieldType.getMetadataString(PROTO_NUMBER_METADATA_TAG)); } + /** Attach the name of the message to a type. */ + public static FieldType withMessageName(FieldType fieldType, String messageName) { + return fieldType.withMetadata(PROTO_MESSAGE_NAME_METADATA_TAG, messageName); + } + + /** Return the message name for a type. */ + public static String getMessageName(FieldType fieldType) { + return fieldType.getMetadataString(PROTO_MESSAGE_NAME_METADATA_TAG); + } + + /** Attach the name of the message to a map key. */ + public static FieldType withMapKeyMessageName(FieldType fieldType, String messageName) { + return fieldType.withMetadata(PROTO_MAP_KEY_MESSAGE_NAME_METADATA_TAG, messageName); + } + + /** Return the message name for a map key. */ + public static String getMapKeyMessageName(FieldType fieldType) { + return fieldType.getMetadataString(PROTO_MAP_KEY_MESSAGE_NAME_METADATA_TAG); + } + + /** Attach the name of the message to a map value. */ + public static FieldType withMapValueMessageName(FieldType fieldType, String messageName) { + return fieldType.withMetadata(PROTO_MAP_VALUE_MESSAGE_NAME_METADATA_TAG, messageName); + } + + /** Return the message name for a map value. */ + public static String getMapValueMessageName(FieldType fieldType) { + return fieldType.getMetadataString(PROTO_MAP_VALUE_MESSAGE_NAME_METADATA_TAG); + } + /** Return a Beam scheam representing a proto class. 
*/ public static Schema getSchema(Class clazz) { return getSchema(ProtobufUtil.getDescriptorForClass(clazz)); } - private static Schema getSchema(Descriptors.Descriptor descriptor) { + static Schema getSchema(Descriptors.Descriptor descriptor) { Set oneOfFields = Sets.newHashSet(); List fields = Lists.newArrayListWithCapacity(descriptor.getFields().size()); for (OneofDescriptor oneofDescriptor : descriptor.getOneofs()) { @@ -157,8 +194,7 @@ private static Schema getSchema(Descriptors.Descriptor descriptor) { oneOfFields.add(fieldDescriptor.getNumber()); // Store proto field number in metadata. FieldType fieldType = - withFieldNumber( - beamFieldTypeFromProtoField(fieldDescriptor), fieldDescriptor.getNumber()); + withMetaData(beamFieldTypeFromProtoField(fieldDescriptor), fieldDescriptor); subFields.add(Field.nullable(fieldDescriptor.getName(), fieldType)); checkArgument( enumIds.putIfAbsent(fieldDescriptor.getName(), fieldDescriptor.getNumber()) == null); @@ -171,14 +207,34 @@ private static Schema getSchema(Descriptors.Descriptor descriptor) { if (!oneOfFields.contains(fieldDescriptor.getNumber())) { // Store proto field number in metadata. 
FieldType fieldType = - withFieldNumber( - beamFieldTypeFromProtoField(fieldDescriptor), fieldDescriptor.getNumber()); + withMetaData(beamFieldTypeFromProtoField(fieldDescriptor), fieldDescriptor); fields.add(Field.of(fieldDescriptor.getName(), fieldType)); } } return Schema.builder().addFields(fields).build(); } + private static FieldType withMetaData( + FieldType inType, Descriptors.FieldDescriptor fieldDescriptor) { + FieldType fieldType = withFieldNumber(inType, fieldDescriptor.getNumber()); + if (fieldDescriptor.isMapField()) { + FieldDescriptor keyFieldDescriptor = fieldDescriptor.getMessageType().findFieldByName("key"); + FieldDescriptor valueFieldDescriptor = + fieldDescriptor.getMessageType().findFieldByName("value"); + if ((keyFieldDescriptor.getType() == FieldDescriptor.Type.MESSAGE)) { + fieldType = + withMapKeyMessageName(fieldType, keyFieldDescriptor.getMessageType().getFullName()); + } + if ((valueFieldDescriptor.getType() == FieldDescriptor.Type.MESSAGE)) { + fieldType = + withMapValueMessageName(fieldType, valueFieldDescriptor.getMessageType().getFullName()); + } + } else if (fieldDescriptor.getType() == FieldDescriptor.Type.MESSAGE) { + return withMessageName(fieldType, fieldDescriptor.getMessageType().getFullName()); + } + return fieldType; + } + private static FieldType beamFieldTypeFromProtoField( Descriptors.FieldDescriptor protoFieldDescriptor) { FieldType fieldType = null; diff --git a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchemaTest.java b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchemaTest.java new file mode 100644 index 000000000000..14bd131982cb --- /dev/null +++ b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchemaTest.java @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.protobuf; + +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withFieldNumber; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.MAP_PRIMITIVE_PROTO; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.MAP_PRIMITIVE_ROW; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.MAP_PRIMITIVE_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NESTED_PROTO; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NESTED_ROW; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NESTED_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_PROTO_BOOL; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_PROTO_INT32; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_PROTO_PRIMITIVE; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_PROTO_STRING; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_ROW_BOOL; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_ROW_INT32; +import static 
org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_ROW_PRIMITIVE; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_ROW_STRING; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.ONEOF_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.OUTER_ONEOF_PROTO; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.OUTER_ONEOF_ROW; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.OUTER_ONEOF_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.PRIMITIVE_PROTO; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.PRIMITIVE_ROW; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.PRIMITIVE_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REPEATED_PROTO; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REPEATED_ROW; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REPEATED_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.WKT_MESSAGE_PROTO; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.WKT_MESSAGE_ROW; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.WKT_MESSAGE_SCHEMA; +import static org.junit.Assert.assertEquals; + +import com.google.protobuf.Descriptors; +import com.google.protobuf.DynamicMessage; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Message; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.EnumMessage; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.MapPrimitive; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.Nested; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.OneOf; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.OuterOneOf; +import 
org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.Primitive; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.RepeatPrimitive; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.WktMessage; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Collection of tests for values on Protobuf Messages and Rows. */ +@RunWith(JUnit4.class) +public class ProtoDynamicMessageSchemaTest { + + private ProtoDynamicMessageSchema schemaFromDescriptor(Descriptors.Descriptor descriptor) { + ProtoDomain domain = ProtoDomain.buildFrom(descriptor); + return ProtoDynamicMessageSchema.forDescriptor(domain, descriptor); + } + + private DynamicMessage toDynamic(Message message) throws InvalidProtocolBufferException { + return DynamicMessage.parseFrom(message.getDescriptorForType(), message.toByteArray()); + } + + @Test + public void testPrimitiveSchema() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(Primitive.getDescriptor()); + Schema schema = schemaProvider.getSchema(); + assertEquals(PRIMITIVE_SCHEMA, schema); + } + + @Test + public void testPrimitiveProtoToRow() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(Primitive.getDescriptor()); + SerializableFunction toRow = schemaProvider.getToRowFunction(); + assertEquals(PRIMITIVE_ROW, toRow.apply(toDynamic(PRIMITIVE_PROTO))); + } + + @Test + public void testPrimitiveRowToProto() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(Primitive.getDescriptor()); + SerializableFunction fromRow = schemaProvider.getFromRowFunction(); + 
assertEquals(PRIMITIVE_PROTO.toString(), fromRow.apply(PRIMITIVE_ROW).toString()); + } + + @Test + public void testRepeatedSchema() { + ProtoDynamicMessageSchema schemaProvider = + schemaFromDescriptor(RepeatPrimitive.getDescriptor()); + Schema schema = schemaProvider.getSchema(); + assertEquals(REPEATED_SCHEMA, schema); + } + + @Test + public void testRepeatedProtoToRow() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = + schemaFromDescriptor(RepeatPrimitive.getDescriptor()); + SerializableFunction toRow = schemaProvider.getToRowFunction(); + assertEquals(REPEATED_ROW, toRow.apply(toDynamic(REPEATED_PROTO))); + } + + @Test + public void testRepeatedRowToProto() { + ProtoDynamicMessageSchema schemaProvider = + schemaFromDescriptor(RepeatPrimitive.getDescriptor()); + SerializableFunction fromRow = schemaProvider.getFromRowFunction(); + assertEquals(REPEATED_PROTO.toString(), fromRow.apply(REPEATED_ROW).toString()); + } + + // Test map type + @Test + public void testMapSchema() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(MapPrimitive.getDescriptor()); + Schema schema = schemaProvider.getSchema(); + assertEquals(MAP_PRIMITIVE_SCHEMA, schema); + } + + @Test + public void testMapProtoToRow() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(MapPrimitive.getDescriptor()); + SerializableFunction toRow = schemaProvider.getToRowFunction(); + assertEquals(MAP_PRIMITIVE_ROW, toRow.apply(toDynamic(MAP_PRIMITIVE_PROTO))); + } + + @Test + public void testMapRowToProto() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(MapPrimitive.getDescriptor()); + SerializableFunction fromRow = schemaProvider.getFromRowFunction(); + assertEquals(MAP_PRIMITIVE_PROTO.toString(), fromRow.apply(MAP_PRIMITIVE_ROW).toString()); + } + + @Test + public void testNestedSchema() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(Nested.getDescriptor()); 
+ Schema schema = schemaProvider.getSchema(); + assertEquals(NESTED_SCHEMA, schema); + } + + @Test + public void testNestedProtoToRow() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(Nested.getDescriptor()); + SerializableFunction toRow = schemaProvider.getToRowFunction(); + assertEquals(NESTED_ROW, toRow.apply(toDynamic(NESTED_PROTO))); + } + + @Test + public void testNestedRowToProto() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(Nested.getDescriptor()); + SerializableFunction fromRow = schemaProvider.getFromRowFunction(); + // equality doesn't work between dynamic messages and other, + // so we compare string representation + assertEquals(NESTED_PROTO.toString(), fromRow.apply(NESTED_ROW).toString()); + } + + @Test + public void testOneOfSchema() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(OneOf.getDescriptor()); + Schema schema = schemaProvider.getSchema(); + assertEquals(ONEOF_SCHEMA, schema); + } + + @Test + public void testOneOfProtoToRow() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(OneOf.getDescriptor()); + SerializableFunction toRow = schemaProvider.getToRowFunction(); + // equality doesn't work between dynamic messages and other, + // so we compare string representation + assertEquals(ONEOF_ROW_INT32.toString(), toRow.apply(toDynamic(ONEOF_PROTO_INT32)).toString()); + assertEquals(ONEOF_ROW_BOOL.toString(), toRow.apply(toDynamic(ONEOF_PROTO_BOOL)).toString()); + assertEquals( + ONEOF_ROW_STRING.toString(), toRow.apply(toDynamic(ONEOF_PROTO_STRING)).toString()); + assertEquals( + ONEOF_ROW_PRIMITIVE.toString(), toRow.apply(toDynamic(ONEOF_PROTO_PRIMITIVE)).toString()); + } + + @Test + public void testOneOfRowToProto() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(OneOf.getDescriptor()); + SerializableFunction fromRow = 
schemaProvider.getFromRowFunction(); + assertEquals(ONEOF_PROTO_INT32.toString(), fromRow.apply(ONEOF_ROW_INT32).toString()); + assertEquals(ONEOF_PROTO_BOOL.toString(), fromRow.apply(ONEOF_ROW_BOOL).toString()); + assertEquals(ONEOF_PROTO_STRING.toString(), fromRow.apply(ONEOF_ROW_STRING).toString()); + assertEquals(ONEOF_PROTO_PRIMITIVE.toString(), fromRow.apply(ONEOF_ROW_PRIMITIVE).toString()); + } + + @Test + public void testOuterOneOfSchema() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(OuterOneOf.getDescriptor()); + Schema schema = schemaProvider.getSchema(); + assertEquals(OUTER_ONEOF_SCHEMA, schema); + } + + @Test + public void testOuterOneOfProtoToRow() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(OuterOneOf.getDescriptor()); + SerializableFunction toRow = schemaProvider.getToRowFunction(); + // equality doesn't work between dynamic messages and other, + // so we compare string representation + assertEquals(OUTER_ONEOF_ROW.toString(), toRow.apply(toDynamic(OUTER_ONEOF_PROTO)).toString()); + } + + @Test + public void testOuterOneOfRowToProto() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(OuterOneOf.getDescriptor()); + SerializableFunction fromRow = schemaProvider.getFromRowFunction(); + assertEquals(OUTER_ONEOF_PROTO.toString(), fromRow.apply(OUTER_ONEOF_ROW).toString()); + } + + private static final EnumerationType ENUM_TYPE = + EnumerationType.create(ImmutableMap.of("ZERO", 0, "TWO", 2, "THREE", 3)); + private static final Schema ENUM_SCHEMA = + Schema.builder() + .addField( + "enum", + withFieldNumber(Schema.FieldType.logicalType(ENUM_TYPE).withNullable(false), 1)) + .build(); + private static final Row ENUM_ROW = + Row.withSchema(ENUM_SCHEMA).addValues(ENUM_TYPE.valueOf("TWO")).build(); + private static final EnumMessage ENUM_PROTO = + EnumMessage.newBuilder().setEnum(EnumMessage.Enum.TWO).build(); + + @Test + public void testEnumSchema() { + 
ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(EnumMessage.getDescriptor()); + Schema schema = schemaProvider.getSchema(); + assertEquals(ENUM_SCHEMA, schema); + } + + @Test + public void testEnumProtoToRow() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(EnumMessage.getDescriptor()); + SerializableFunction toRow = schemaProvider.getToRowFunction(); + assertEquals(ENUM_ROW, toRow.apply(toDynamic(ENUM_PROTO))); + } + + @Test + public void testEnumRowToProto() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(EnumMessage.getDescriptor()); + SerializableFunction fromRow = schemaProvider.getFromRowFunction(); + assertEquals(ENUM_PROTO.toString(), fromRow.apply(ENUM_ROW).toString()); + } + + @Test + public void testWktMessageSchema() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(WktMessage.getDescriptor()); + Schema schema = schemaProvider.getSchema(); + assertEquals(WKT_MESSAGE_SCHEMA, schema); + } + + @Test + public void testWktProtoToRow() throws InvalidProtocolBufferException { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(WktMessage.getDescriptor()); + SerializableFunction toRow = schemaProvider.getToRowFunction(); + assertEquals(WKT_MESSAGE_ROW, toRow.apply(toDynamic(WKT_MESSAGE_PROTO))); + } + + @Test + public void testWktRowToProto() { + ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(WktMessage.getDescriptor()); + SerializableFunction fromRow = schemaProvider.getFromRowFunction(); + assertEquals(WKT_MESSAGE_PROTO.toString(), fromRow.apply(WKT_MESSAGE_ROW).toString()); + } +} diff --git a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/TestProtoSchemas.java b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/TestProtoSchemas.java index 10ed0ece2706..862d637c8506 100644 --- 
a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/TestProtoSchemas.java +++ b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/TestProtoSchemas.java @@ -19,6 +19,8 @@ import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.getFieldNumber; import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withFieldNumber; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withMapValueMessageName; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withMessageName; import com.google.protobuf.BoolValue; import com.google.protobuf.ByteString; @@ -276,12 +278,21 @@ class TestProtoSchemas { static final Schema NESTED_SCHEMA = Schema.builder() .addField( - "nested", withFieldNumber(FieldType.row(PRIMITIVE_SCHEMA).withNullable(true), 1)) + "nested", + withMessageName( + withFieldNumber(FieldType.row(PRIMITIVE_SCHEMA).withNullable(true), 1), + "proto3_schema_messages.Primitive")) .addField( - "nested_list", withFieldNumber(FieldType.array(FieldType.row(PRIMITIVE_SCHEMA)), 2)) + "nested_list", + withMessageName( + withFieldNumber(FieldType.array(FieldType.row(PRIMITIVE_SCHEMA)), 2), + "proto3_schema_messages.Primitive")) .addField( "nested_map", - withFieldNumber(FieldType.map(FieldType.STRING, FieldType.row(PRIMITIVE_SCHEMA)), 3)) + withMapValueMessageName( + withFieldNumber( + FieldType.map(FieldType.STRING, FieldType.row(PRIMITIVE_SCHEMA)), 3), + "proto3_schema_messages.Primitive")) .build(); // A sample instance of the row. 
@@ -306,7 +317,11 @@ class TestProtoSchemas { Field.of("oneof_int32", withFieldNumber(FieldType.INT32, 2)), Field.of("oneof_bool", withFieldNumber(FieldType.BOOLEAN, 3)), Field.of("oneof_string", withFieldNumber(FieldType.STRING, 4)), - Field.of("oneof_primitive", withFieldNumber(FieldType.row(PRIMITIVE_SCHEMA), 5))); + Field.of( + "oneof_primitive", + withMessageName( + withFieldNumber(FieldType.row(PRIMITIVE_SCHEMA), 5), + "proto3_schema_messages.Primitive"))); private static final Map ONE_OF_ENUM_MAP = ONEOF_FIELDS.stream() .collect(Collectors.toMap(Field::getName, f -> getFieldNumber(f.getType()))); @@ -349,7 +364,10 @@ class TestProtoSchemas { // The schema for the OuterOneOf proto. private static final List OUTER_ONEOF_FIELDS = ImmutableList.of( - Field.of("oneof_oneof", withFieldNumber(FieldType.row(ONEOF_SCHEMA), 1)), + Field.of( + "oneof_oneof", + withMessageName( + withFieldNumber(FieldType.row(ONEOF_SCHEMA), 1), "proto3_schema_messages.OneOf")), Field.of("oneof_int32", withFieldNumber(FieldType.INT32, 2))); private static final Map OUTER_ONE_OF_ENUM_MAP = OUTER_ONEOF_FIELDS.stream() @@ -371,19 +389,47 @@ class TestProtoSchemas { static final Schema WKT_MESSAGE_SCHEMA = Schema.builder() - .addNullableField("double", withFieldNumber(FieldType.DOUBLE, 1)) - .addNullableField("float", withFieldNumber(FieldType.FLOAT, 2)) - .addNullableField("int32", withFieldNumber(FieldType.INT32, 3)) - .addNullableField("int64", withFieldNumber(FieldType.INT64, 4)) - .addNullableField("uint32", withFieldNumber(FieldType.logicalType(new UInt32()), 5)) - .addNullableField("uint64", withFieldNumber(FieldType.logicalType(new UInt64()), 6)) - .addNullableField("bool", withFieldNumber(FieldType.BOOLEAN, 13)) - .addNullableField("string", withFieldNumber(FieldType.STRING, 14)) - .addNullableField("bytes", withFieldNumber(FieldType.BYTES, 15)) .addNullableField( - "timestamp", withFieldNumber(FieldType.logicalType(new NanosInstant()), 16)) + "double", + 
withMessageName(withFieldNumber(FieldType.DOUBLE, 1), "google.protobuf.DoubleValue")) + .addNullableField( + "float", + withMessageName(withFieldNumber(FieldType.FLOAT, 2), "google.protobuf.FloatValue")) + .addNullableField( + "int32", + withMessageName(withFieldNumber(FieldType.INT32, 3), "google.protobuf.Int32Value")) + .addNullableField( + "int64", + withMessageName(withFieldNumber(FieldType.INT64, 4), "google.protobuf.Int64Value")) + .addNullableField( + "uint32", + withMessageName( + withFieldNumber(FieldType.logicalType(new UInt32()), 5), + "google.protobuf.UInt32Value")) + .addNullableField( + "uint64", + withMessageName( + withFieldNumber(FieldType.logicalType(new UInt64()), 6), + "google.protobuf.UInt64Value")) + .addNullableField( + "bool", + withMessageName(withFieldNumber(FieldType.BOOLEAN, 13), "google.protobuf.BoolValue")) + .addNullableField( + "string", + withMessageName(withFieldNumber(FieldType.STRING, 14), "google.protobuf.StringValue")) + .addNullableField( + "bytes", + withMessageName(withFieldNumber(FieldType.BYTES, 15), "google.protobuf.BytesValue")) + .addNullableField( + "timestamp", + withMessageName( + withFieldNumber(FieldType.logicalType(new NanosInstant()), 16), + "google.protobuf.Timestamp")) .addNullableField( - "duration", withFieldNumber(FieldType.logicalType(new NanosDuration()), 17)) + "duration", + withMessageName( + withFieldNumber(FieldType.logicalType(new NanosDuration()), 17), + "google.protobuf.Duration")) .build(); // A sample instance of the row. static final Instant JAVA_NOW = Instant.now(); @@ -426,7 +472,9 @@ class TestProtoSchemas { Schema.builder() .addField( "nested", - withFieldNumber(FieldType.row(OPTIONAL_PRIMITIVE_SCHEMA), 1).withNullable(true)) + withMessageName( + withFieldNumber(FieldType.row(OPTIONAL_PRIMITIVE_SCHEMA), 1).withNullable(true), + "proto2_schema_messages.OptionalPrimitive")) .build(); // A sample instance of the proto. 
@@ -438,7 +486,9 @@ class TestProtoSchemas { Schema.builder() .addField( "nested", - withFieldNumber(FieldType.row(REQUIRED_PRIMITIVE_SCHEMA), 1).withNullable(false)) + withMessageName( + withFieldNumber(FieldType.row(REQUIRED_PRIMITIVE_SCHEMA), 1).withNullable(false), + "proto2_schema_messages.RequiredPrimitive")) .build(); // A sample instance of the proto.