diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java index 79c18199a97b..37c643af9b83 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/logicaltypes/OneOfType.java @@ -55,9 +55,7 @@ private OneOfType(List fields) { private OneOfType(List fields, @Nullable Map enumMap) { List nullableFields = - fields.stream() - .map(f -> Field.nullable(f.getName(), f.getType())) - .collect(Collectors.toList()); + fields.stream().map(f -> f.withNullable(true)).collect(Collectors.toList()); if (enumMap != null) { nullableFields.stream().forEach(f -> checkArgument(enumMap.containsKey(f.getName()))); enumerationType = EnumerationType.create(enumMap); diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java index b6814522596f..9ac73258343d 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoByteBuddyUtils.java @@ -922,7 +922,7 @@ private static FieldValueGetter createGetter( fieldValueTypeSupplier.get(clazz, oneOfType.getOneOfSchema()).stream() .collect(Collectors.toMap(FieldValueTypeInformation::getName, f -> f)); for (Field oneOfField : oneOfType.getOneOfSchema().getFields()) { - int protoFieldIndex = getFieldNumber(oneOfField.getType()); + int protoFieldIndex = getFieldNumber(oneOfField); FieldValueGetter oneOfFieldGetter = createGetter( oneOfFieldTypes.get(oneOfField.getName()), @@ -993,7 +993,7 @@ FieldValueSetter getProtoFieldValueSetter( TreeMap> oneOfSetters = Maps.newTreeMap(); for (Field oneOfField : oneOfType.getOneOfSchema().getFields()) { FieldValueSetter setter = getProtoFieldValueSetter(oneOfField, methods, builderClass); - oneOfSetters.put(getFieldNumber(oneOfField.getType()), setter); + oneOfSetters.put(getFieldNumber(oneOfField), setter); } return createOneOfSetter(field.getName(), oneOfSetters, builderClass); } else { diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDomain.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDomain.java index c13cf882c547..3da07ccdf7a8 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDomain.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDomain.java @@ -17,8 +17,20 @@ */ package org.apache.beam.sdk.extensions.protobuf; +import com.google.protobuf.Any; +import com.google.protobuf.Api; import com.google.protobuf.DescriptorProtos; import com.google.protobuf.Descriptors; +import com.google.protobuf.Duration; +import com.google.protobuf.DynamicMessage; +import com.google.protobuf.Empty; +import com.google.protobuf.ExtensionRegistry; +import com.google.protobuf.FieldMask; +import com.google.protobuf.Int32Value; +import com.google.protobuf.SourceContext; +import com.google.protobuf.Struct; +import com.google.protobuf.Timestamp; +import com.google.protobuf.Type; import java.io.IOException; import java.io.InputStream; import java.io.ObjectInputStream; @@ -77,19 +89,53 @@ private static Map extractProtoMap private static Descriptors.FileDescriptor convertToFileDescriptorMap( String name, Map inMap, - Map outMap) { + Map outMap, + ExtensionRegistry registry) { if (outMap.containsKey(name)) { return outMap.get(name); } DescriptorProtos.FileDescriptorProto fileDescriptorProto = inMap.get(name); if (fileDescriptorProto == null) { - if ("google/protobuf/descriptor.proto".equals(name)) { - outMap.put( - "google/protobuf/descriptor.proto", - DescriptorProtos.FieldOptions.getDescriptor().getFile()); - return DescriptorProtos.FieldOptions.getDescriptor().getFile(); + Descriptors.FileDescriptor fd; + switch (name) { + case "google/protobuf/descriptor.proto": + fd = DescriptorProtos.FieldOptions.getDescriptor().getFile(); + break; + case "google/protobuf/wrappers.proto": + fd = Int32Value.getDescriptor().getFile(); + break; + case "google/protobuf/timestamp.proto": + fd = Timestamp.getDescriptor().getFile(); + break; + case "google/protobuf/duration.proto": + fd = Duration.getDescriptor().getFile(); + break; + case "google/protobuf/any.proto": + fd = Any.getDescriptor().getFile(); + break; + case "google/protobuf/api.proto": + fd = Api.getDescriptor().getFile(); + break; + case "google/protobuf/empty.proto": + fd = Empty.getDescriptor().getFile(); + break; + case "google/protobuf/field_mask.proto": + fd = FieldMask.getDescriptor().getFile(); + break; + case "google/protobuf/source_context.proto": + fd = SourceContext.getDescriptor().getFile(); + break; + case "google/protobuf/struct.proto": + fd = Struct.getDescriptor().getFile(); + break; + case "google/protobuf/type.proto": + fd = Type.getDescriptor().getFile(); + break; + default: + return null; } - return null; + outMap.put(name, fd); + return fd; } else { List dependencies = new ArrayList<>(); if (fileDescriptorProto.getDependencyCount() > 0) { @@ -98,7 +144,7 @@ private static Descriptors.FileDescriptor convertToFileDescriptorMap( .forEach( dependencyName -> { Descriptors.FileDescriptor fileDescriptor = - convertToFileDescriptorMap(dependencyName, inMap, outMap); + convertToFileDescriptorMap(dependencyName, inMap, outMap, registry); if (fileDescriptor != null) { dependencies.add(fileDescriptor); } @@ -108,6 +154,18 @@ private static Descriptors.FileDescriptor convertToFileDescriptorMap( Descriptors.FileDescriptor fileDescriptor = Descriptors.FileDescriptor.buildFrom( fileDescriptorProto, dependencies.toArray(new Descriptors.FileDescriptor[0])); + fileDescriptor + .getExtensions() + .forEach( + extension -> { + if (extension.getType() == Descriptors.FieldDescriptor.Type.MESSAGE) { + registry.add( + extension, DynamicMessage.newBuilder(extension.getMessageType()).build()); + } else { + registry.add(extension); + } + }); + Descriptors.FileDescriptor.internalUpdateFileDescriptor(fileDescriptor, registry); outMap.put(name, fileDescriptor); return fileDescriptor; } catch (Descriptors.DescriptorValidationException e) { @@ -147,10 +205,14 @@ public static ProtoDomain buildFrom(InputStream inputStream) throws IOException private void crosswire() { HashMap map = new HashMap<>(); - fileDescriptorSet.getFileList().forEach(fdp -> map.put(fdp.getName(), fdp)); + fileDescriptorSet.getFileList().stream() + .filter(fdp -> !fdp.getName().startsWith("google/protobuf")) + .forEach(fdp -> map.put(fdp.getName(), fdp)); + ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance(); Map outMap = new HashMap<>(); - map.forEach((fileName, proto) -> convertToFileDescriptorMap(fileName, map, outMap)); + map.forEach( + (fileName, proto) -> convertToFileDescriptorMap(fileName, map, outMap, extensionRegistry)); fileDescriptorMap = outMap; indexOptionsByNumber(fileDescriptorMap.values()); diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchema.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchema.java index 31856583058f..5a993eecef90 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchema.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchema.java @@ -17,12 +17,10 @@ */ package org.apache.beam.sdk.extensions.protobuf; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.SCHEMA_OPTION_META_NUMBER; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.SCHEMA_OPTION_META_TYPE_NAME; import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.getFieldNumber; -import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.getMapKeyMessageName; -import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.getMapValueMessageName; -import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.getMessageName; import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withFieldNumber; -import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withMessageName; import com.google.protobuf.ByteString; import com.google.protobuf.Descriptors; @@ -110,28 +108,34 @@ private Object readResolve() { Convert createConverter(Schema.Field field) { Schema.FieldType fieldType = field.getType(); - String messageName = getMessageName(fieldType); - if (messageName != null && messageName.length() > 0) { + if (fieldType.getNullable()) { Schema.Field valueField = - Schema.Field.of("value", withFieldNumber(Schema.FieldType.BOOLEAN, 1)); - switch (messageName) { - case "google.protobuf.StringValue": - case "google.protobuf.DoubleValue": - case "google.protobuf.FloatValue": - case "google.protobuf.BoolValue": - case "google.protobuf.Int64Value": - case "google.protobuf.Int32Value": - case "google.protobuf.UInt64Value": - case "google.protobuf.UInt32Value": + withFieldNumber(Schema.Field.of("value", Schema.FieldType.BOOLEAN), 1); + switch (fieldType.getTypeName()) { + case BYTE: + case INT16: + case INT32: + case INT64: + case FLOAT: + case DOUBLE: + case STRING: + case BOOLEAN: return new WrapperConvert(field, new PrimitiveConvert(valueField)); - case "google.protobuf.BytesValue": + case BYTES: return new WrapperConvert(field, new BytesConvert(valueField)); - case "google.protobuf.Timestamp": - case "google.protobuf.Duration": - // handled by logical type case - break; + case LOGICAL_TYPE: + String identifier = field.getType().getLogicalType().getIdentifier(); + switch (identifier) { + case ProtoSchemaLogicalTypes.UInt32.IDENTIFIER: + case ProtoSchemaLogicalTypes.UInt64.IDENTIFIER: + return new WrapperConvert(field, new PrimitiveConvert(valueField)); + default: + } + // fall through + default: } } + switch (fieldType.getTypeName()) { case BYTE: case INT16: @@ -260,7 +264,8 @@ public DynamicMessage.Builder invokeNewBuilder() { @Override public Context getSubContext(Schema.Field field) { - String messageName = getMessageName(field.getType()); + String messageName = + field.getType().getRowSchema().getOptions().getValue(SCHEMA_OPTION_META_TYPE_NAME); return new DescriptorContext(messageName, domain); } } @@ -274,9 +279,10 @@ abstract static class Convert { private int number; Convert(Schema.Field field) { - try { - this.number = getFieldNumber(field.getType()); - } catch (NumberFormatException e) { + Schema.Options options = field.getOptions(); + if (options.hasOption(SCHEMA_OPTION_META_NUMBER)) { + this.number = options.getValue(SCHEMA_OPTION_META_NUMBER); + } else { this.number = -1; } } @@ -546,16 +552,8 @@ static class MapConvert extends Convert { MapConvert(ProtoDynamicMessageSchema protoSchema, Schema.Field field) { super(field); Schema.FieldType fieldType = field.getType(); - key = - protoSchema.createConverter( - Schema.Field.of( - "KEY", - withMessageName(fieldType.getMapKeyType(), getMapKeyMessageName(fieldType)))); - value = - protoSchema.createConverter( - Schema.Field.of( - "VALUE", - withMessageName(fieldType.getMapValueType(), getMapValueMessageName(fieldType)))); + key = protoSchema.createConverter(Schema.Field.of("KEY", fieldType.getMapKeyType())); + value = protoSchema.createConverter(Schema.Field.of("VALUE", fieldType.getMapValueType())); } @Override @@ -617,11 +615,7 @@ static class ArrayConvert extends Convert { ArrayConvert(ProtoDynamicMessageSchema protoSchema, Schema.Field field) { super(field); Schema.FieldType collectionElementType = field.getType().getCollectionElementType(); - this.element = - protoSchema.createConverter( - Schema.Field.of( - "ELEMENT", - withMessageName(collectionElementType, getMessageName(field.getType())))); + this.element = protoSchema.createConverter(Schema.Field.of("ELEMENT", collectionElementType)); } @Override @@ -703,9 +697,11 @@ static class OneOfConvert extends Convert { super(field); this.logicalType = (OneOfType) logicalType; for (Schema.Field oneOfField : this.logicalType.getOneOfSchema().getFields()) { - int fieldNumber = getFieldNumber(oneOfField.getType()); + int fieldNumber = getFieldNumber(oneOfField); oneOfConvert.put( - fieldNumber, new NullableConvert(oneOfField, protoSchema.createConverter(oneOfField))); + fieldNumber, + new NullableConvert( + oneOfField, protoSchema.createConverter(oneOfField.withNullable(false)))); } } diff --git a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslator.java b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslator.java index 29810905e43e..dd43096943b4 100644 --- a/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslator.java +++ b/sdks/java/extensions/protobuf/src/main/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslator.java @@ -126,54 +126,25 @@ * */ class ProtoSchemaTranslator { - /** This METADATA tag is used to store the field number of a proto tag. */ - public static final String PROTO_NUMBER_METADATA_TAG = "PROTO_NUMBER"; + public static final String SCHEMA_OPTION_META_NUMBER = "beam:option:proto:meta:number"; - public static final String PROTO_MESSAGE_NAME_METADATA_TAG = "PROTO_MESSAGE_NAME"; + public static final String SCHEMA_OPTION_META_TYPE_NAME = "beam:option:proto:meta:type_name"; - public static final String PROTO_MAP_KEY_MESSAGE_NAME_METADATA_TAG = "PROTO_MAP_KEY_MESSAGE_NAME"; + /** Option prefix for options on messages. */ + public static final String SCHEMA_OPTION_MESSAGE_PREFIX = "beam:option:proto:message:"; - public static final String PROTO_MAP_VALUE_MESSAGE_NAME_METADATA_TAG = - "PROTO_MAP_VALUE_MESSAGE_NAME"; + /** Option prefix for options on fields. */ + public static final String SCHEMA_OPTION_FIELD_PREFIX = "beam:option:proto:field:"; /** Attach a proto field number to a type. */ - static FieldType withFieldNumber(FieldType fieldType, int index) { - return fieldType.withMetadata(PROTO_NUMBER_METADATA_TAG, Long.toString(index)); + static Field withFieldNumber(Field field, int number) { + return field.withOptions( + Schema.Options.builder().setOption(SCHEMA_OPTION_META_NUMBER, FieldType.INT32, number)); } /** Return the proto field number for a type. */ - static int getFieldNumber(FieldType fieldType) { - return Integer.parseInt(fieldType.getMetadataString(PROTO_NUMBER_METADATA_TAG)); - } - - /** Attach the name of the message to a type. */ - public static FieldType withMessageName(FieldType fieldType, String messageName) { - return fieldType.withMetadata(PROTO_MESSAGE_NAME_METADATA_TAG, messageName); - } - - /** Return the message name for a type. */ - public static String getMessageName(FieldType fieldType) { - return fieldType.getMetadataString(PROTO_MESSAGE_NAME_METADATA_TAG); - } - - /** Attach the name of the message to a map key. */ - public static FieldType withMapKeyMessageName(FieldType fieldType, String messageName) { - return fieldType.withMetadata(PROTO_MAP_KEY_MESSAGE_NAME_METADATA_TAG, messageName); - } - - /** Return the message name for a map key. */ - public static String getMapKeyMessageName(FieldType fieldType) { - return fieldType.getMetadataString(PROTO_MAP_KEY_MESSAGE_NAME_METADATA_TAG); - } - - /** Attach the name of the message to a map value. */ - public static FieldType withMapValueMessageName(FieldType fieldType, String messageName) { - return fieldType.withMetadata(PROTO_MAP_VALUE_MESSAGE_NAME_METADATA_TAG, messageName); - } - - /** Return the message name for a map value. */ - public static String getMapValueMessageName(FieldType fieldType) { - return fieldType.getMetadataString(PROTO_MAP_VALUE_MESSAGE_NAME_METADATA_TAG); + static int getFieldNumber(Field field) { + return field.getOptions().getValue(SCHEMA_OPTION_META_NUMBER); } /** Return a Beam scheam representing a proto class. */ @@ -189,10 +160,11 @@ static Schema getSchema(Descriptors.Descriptor descriptor) { Map enumIds = Maps.newHashMap(); for (FieldDescriptor fieldDescriptor : oneofDescriptor.getFields()) { oneOfFields.add(fieldDescriptor.getNumber()); - // Store proto field number in metadata. - FieldType fieldType = - withMetaData(beamFieldTypeFromProtoField(fieldDescriptor), fieldDescriptor); - subFields.add(Field.nullable(fieldDescriptor.getName(), fieldType)); + // Store proto field number in a field option. + FieldType fieldType = beamFieldTypeFromProtoField(fieldDescriptor); + subFields.add( + withFieldNumber( + Field.nullable(fieldDescriptor.getName(), fieldType), fieldDescriptor.getNumber())); checkArgument( enumIds.putIfAbsent(fieldDescriptor.getName(), fieldDescriptor.getNumber()) == null); } @@ -203,33 +175,20 @@ static Schema getSchema(Descriptors.Descriptor descriptor) { for (Descriptors.FieldDescriptor fieldDescriptor : descriptor.getFields()) { if (!oneOfFields.contains(fieldDescriptor.getNumber())) { // Store proto field number in metadata. - FieldType fieldType = - withMetaData(beamFieldTypeFromProtoField(fieldDescriptor), fieldDescriptor); - fields.add(Field.of(fieldDescriptor.getName(), fieldType)); + FieldType fieldType = beamFieldTypeFromProtoField(fieldDescriptor); + fields.add( + withFieldNumber( + Field.of(fieldDescriptor.getName(), fieldType), fieldDescriptor.getNumber()) + .withOptions(getFieldOptions(fieldDescriptor))); } } - return Schema.builder().addFields(fields).build(); - } - - private static FieldType withMetaData( - FieldType inType, Descriptors.FieldDescriptor fieldDescriptor) { - FieldType fieldType = withFieldNumber(inType, fieldDescriptor.getNumber()); - if (fieldDescriptor.isMapField()) { - FieldDescriptor keyFieldDescriptor = fieldDescriptor.getMessageType().findFieldByName("key"); - FieldDescriptor valueFieldDescriptor = - fieldDescriptor.getMessageType().findFieldByName("value"); - if ((keyFieldDescriptor.getType() == FieldDescriptor.Type.MESSAGE)) { - fieldType = - withMapKeyMessageName(fieldType, keyFieldDescriptor.getMessageType().getFullName()); - } - if ((valueFieldDescriptor.getType() == FieldDescriptor.Type.MESSAGE)) { - fieldType = - withMapValueMessageName(fieldType, valueFieldDescriptor.getMessageType().getFullName()); - } - } else if (fieldDescriptor.getType() == FieldDescriptor.Type.MESSAGE) { - return withMessageName(fieldType, fieldDescriptor.getMessageType().getFullName()); - } - return fieldType; + return Schema.builder() + .addFields(fields) + .setOptions( + getSchemaOptions(descriptor) + .setOption( + SCHEMA_OPTION_META_TYPE_NAME, FieldType.STRING, descriptor.getFullName())) + .build(); } private static FieldType beamFieldTypeFromProtoField( @@ -352,4 +311,50 @@ private static FieldType beamFieldTypeFromSingularProtoField( } return fieldType; } + + private static Schema.Options.Builder getFieldOptions(FieldDescriptor fieldDescriptor) { + return getOptions(SCHEMA_OPTION_FIELD_PREFIX, fieldDescriptor.getOptions().getAllFields()); + } + + private static Schema.Options.Builder getSchemaOptions(Descriptors.Descriptor descriptor) { + return getOptions(SCHEMA_OPTION_MESSAGE_PREFIX, descriptor.getOptions().getAllFields()); + } + + private static Schema.Options.Builder getOptions( + String prefix, Map allFields) { + Schema.Options.Builder optionsBuilder = Schema.Options.builder(); + for (Map.Entry entry : allFields.entrySet()) { + FieldDescriptor fieldDescriptor = entry.getKey(); + FieldType fieldType = beamFieldTypeFromProtoField(fieldDescriptor); + + switch (fieldType.getTypeName()) { + case BYTE: + case BYTES: + case INT16: + case INT32: + case INT64: + case DECIMAL: + case FLOAT: + case DOUBLE: + case STRING: + case BOOLEAN: + case LOGICAL_TYPE: + case ROW: + case ARRAY: + case ITERABLE: + Field field = Field.of("OPTION", fieldType); + ProtoDynamicMessageSchema schema = ProtoDynamicMessageSchema.forSchema(Schema.of(field)); + optionsBuilder.setOption( + prefix + fieldDescriptor.getFullName(), + fieldType, + schema.createConverter(field).convertFromProtoValue(entry.getValue())); + break; + case MAP: + case DATETIME: + default: + throw new IllegalStateException("These datatypes are not possible in extentions."); + } + } + return optionsBuilder; + } } diff --git a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoDomainTest.java b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoDomainTest.java index 5ff909bdf7ab..31a652ff2b35 100644 --- a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoDomainTest.java +++ b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoDomainTest.java @@ -19,6 +19,8 @@ import com.google.protobuf.Int32Value; import com.google.protobuf.Int64Value; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.Nested; +import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.Primitive; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; @@ -37,19 +39,19 @@ public void testNamespaceEqual() { @Test public void testContainsDescriptor() { - ProtoDomain domainFromInt32 = ProtoDomain.buildFrom(Int32Value.getDescriptor()); - Assert.assertTrue(domainFromInt32.contains(Int32Value.getDescriptor())); + ProtoDomain domainFromInt32 = ProtoDomain.buildFrom(Primitive.getDescriptor()); + Assert.assertTrue(domainFromInt32.contains(Primitive.getDescriptor())); } @Test public void testContainsOtherDescriptorSameFile() { - ProtoDomain domain = ProtoDomain.buildFrom(Int32Value.getDescriptor()); - Assert.assertTrue(domain.contains(Int64Value.getDescriptor())); + ProtoDomain domain = ProtoDomain.buildFrom(Primitive.getDescriptor()); + Assert.assertTrue(domain.contains(Nested.getDescriptor())); } @Test public void testBuildForFile() { - ProtoDomain domain = ProtoDomain.buildFrom(Int32Value.getDescriptor().getFile()); + ProtoDomain domain = ProtoDomain.buildFrom(Primitive.getDescriptor().getFile()); Assert.assertNotNull(domain.getFileDescriptor("google/protobuf/wrappers.proto")); } } diff --git a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchemaTest.java b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchemaTest.java index 0b7984d3c32e..17b5106f2626 100644 --- a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchemaTest.java +++ b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoDynamicMessageSchemaTest.java @@ -17,7 +17,6 @@ */ package org.apache.beam.sdk.extensions.protobuf; -import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withFieldNumber; import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.MAP_PRIMITIVE_PROTO; import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.MAP_PRIMITIVE_ROW; import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.MAP_PRIMITIVE_SCHEMA; @@ -49,6 +48,8 @@ import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.WKT_MESSAGE_PROTO; import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.WKT_MESSAGE_ROW; import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.WKT_MESSAGE_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.withFieldNumber; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.withTypeName; import static org.junit.Assert.assertEquals; import com.google.protobuf.Descriptors; @@ -265,8 +266,9 @@ public void testOuterOneOfRowToProto() { private static final Schema ENUM_SCHEMA = Schema.builder() .addField( - "enum", - withFieldNumber(Schema.FieldType.logicalType(ENUM_TYPE).withNullable(false), 1)) + withFieldNumber("enum", Schema.FieldType.logicalType(ENUM_TYPE), 1) + .withNullable(false)) + .setOptions(withTypeName("proto3_schema_messages.EnumMessage")) .build(); private static final Row ENUM_ROW = Row.withSchema(ENUM_SCHEMA).addValues(ENUM_TYPE.valueOf("TWO")).build(); diff --git a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchemaTest.java b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchemaTest.java index c8b182c75fcb..f026ab37babf 100644 --- a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchemaTest.java +++ b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoMessageSchemaTest.java @@ -17,7 +17,6 @@ */ package org.apache.beam.sdk.extensions.protobuf; -import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withFieldNumber; import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.MAP_PRIMITIVE_PROTO; import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.MAP_PRIMITIVE_ROW; import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.MAP_PRIMITIVE_SCHEMA; @@ -55,6 +54,8 @@ import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.WKT_MESSAGE_PROTO; import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.WKT_MESSAGE_ROW; import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.WKT_MESSAGE_SCHEMA; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.withFieldNumber; +import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.withTypeName; import static org.junit.Assert.assertEquals; import org.apache.beam.sdk.extensions.protobuf.Proto2SchemaMessages.OptionalPrimitive; @@ -281,7 +282,8 @@ public void testOuterOneOfRowToProto() { EnumerationType.create(ImmutableMap.of("ZERO", 0, "TWO", 2, "THREE", 3)); private static final Schema ENUM_SCHEMA = Schema.builder() - .addField("enum", withFieldNumber(FieldType.logicalType(ENUM_TYPE), 1)) + .addField(withFieldNumber("enum", FieldType.logicalType(ENUM_TYPE), 1)) + .setOptions(withTypeName("proto3_schema_messages.EnumMessage")) .build(); private static final Row ENUM_ROW = Row.withSchema(ENUM_SCHEMA).addValues(ENUM_TYPE.valueOf("TWO")).build(); diff --git a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslatorTest.java b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslatorTest.java index 5ff6dc00d291..4e3cf9a59afc 100644 --- a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslatorTest.java +++ b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/ProtoSchemaTranslatorTest.java @@ -19,6 +19,11 @@ import static org.junit.Assert.assertEquals; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.values.Row; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -102,4 +107,125 @@ public void testRequiredNestedSchema() { TestProtoSchemas.REQUIRED_NESTED_SCHEMA, ProtoSchemaTranslator.getSchema(Proto2SchemaMessages.RequiredNested.class)); } + + @Test + public void testOptionsInt32OnMessage() { + Schema schema = ProtoSchemaTranslator.getSchema(Proto3SchemaMessages.OptionMessage.class); + assertEquals( + Integer.valueOf(42), + schema + .getOptions() + .getValue("beam:option:proto:message:proto3_schema_options.message_option_int")); + } + + @Test + public void testOptionsStringOnMessage() { + Schema schema = ProtoSchemaTranslator.getSchema(Proto3SchemaMessages.OptionMessage.class); + assertEquals( + "this is a message string", + schema + .getOptions() + .getValue("beam:option:proto:message:proto3_schema_options.message_option_string")); + } + + @Test + public void testOptionsMessageOnMessage() { + Schema schema = ProtoSchemaTranslator.getSchema(Proto3SchemaMessages.OptionMessage.class); + Row optionMessage = + schema + .getOptions() + .getValue("beam:option:proto:message:proto3_schema_options.message_option_message"); + assertEquals("foobar in message", optionMessage.getString("single_string")); + assertEquals(Integer.valueOf(12), optionMessage.getInt32("single_int32")); + assertEquals(Long.valueOf(34L), optionMessage.getInt64("single_int64")); + } + + @Test + public void testOptionArrayOnMessage() { + Schema schema = ProtoSchemaTranslator.getSchema(Proto3SchemaMessages.OptionMessage.class); + List verify = + new ArrayList( + (Collection) + schema + .getOptions() + .getValue( + "beam:option:proto:message:proto3_schema_options.message_option_repeated")); + assertEquals("string_1", verify.get(0)); + assertEquals("string_2", verify.get(1)); + assertEquals("string_3", verify.get(2)); + } + + @Test + public void testOptionArrayOfMessagesOnMessage() { + Schema schema = ProtoSchemaTranslator.getSchema(Proto3SchemaMessages.OptionMessage.class); + List verify = + new ArrayList( + (Collection) + schema + .getOptions() + .getValue( + "beam:option:proto:message:proto3_schema_options.message_option_repeated_message")); + assertEquals( + "string in message in option in message", ((Row) verify.get(0)).getString("single_string")); + assertEquals(Integer.valueOf(1), ((Row) verify.get(1)).getInt32("single_int32")); + assertEquals(Long.valueOf(2L), ((Row) verify.get(2)).getInt64("single_int64")); + } + + @Test + public void testOptionsInt32OnField() { + Schema schema = ProtoSchemaTranslator.getSchema(Proto3SchemaMessages.OptionMessage.class); + Schema.Options options = schema.getField("field_one").getOptions(); + assertEquals( + Integer.valueOf(13), + options.getValue("beam:option:proto:field:proto3_schema_options.field_option_int")); + } + + @Test + public void testOptionsStringOnField() { + Schema schema = ProtoSchemaTranslator.getSchema(Proto3SchemaMessages.OptionMessage.class); + Schema.Options options = schema.getField("field_one").getOptions(); + assertEquals( + "this is a field string", + options.getValue("beam:option:proto:field:proto3_schema_options.field_option_string")); + } + + @Test + public void testOptionsMessageOnField() { + Schema schema = ProtoSchemaTranslator.getSchema(Proto3SchemaMessages.OptionMessage.class); + Schema.Options options = schema.getField("field_one").getOptions(); + Row optionMessage = + options.getValue("beam:option:proto:field:proto3_schema_options.field_option_message"); + assertEquals("foobar in field", optionMessage.getString("single_string")); + assertEquals(Integer.valueOf(56), optionMessage.getInt32("single_int32")); + assertEquals(Long.valueOf(78L), optionMessage.getInt64("single_int64")); + } + + @Test + public void testOptionArrayOnField() { + Schema schema = ProtoSchemaTranslator.getSchema(Proto3SchemaMessages.OptionMessage.class); + Schema.Options options = schema.getField("field_one").getOptions(); + List verify = + new ArrayList( + (Collection) + options.getValue( + "beam:option:proto:field:proto3_schema_options.field_option_repeated")); + assertEquals("field_string_1", verify.get(0)); + assertEquals("field_string_2", verify.get(1)); + assertEquals("field_string_3", verify.get(2)); + } + + @Test + public void testOptionArrayOfMessagesOnField() { + Schema schema = ProtoSchemaTranslator.getSchema(Proto3SchemaMessages.OptionMessage.class); + Schema.Options options = schema.getField("field_one").getOptions(); + List verify = + new ArrayList( + (Collection) + options.getValue( + "beam:option:proto:field:proto3_schema_options.field_option_repeated_message")); + assertEquals( + "string in message in option in field", ((Row) verify.get(0)).getString("single_string")); + assertEquals(Integer.valueOf(77), ((Row) verify.get(1)).getInt32("single_int32")); + assertEquals(Long.valueOf(88L), ((Row) verify.get(2)).getInt64("single_int64")); + } } diff --git a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/TestProtoSchemas.java b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/TestProtoSchemas.java index 6367659e15dd..d7c9b8bcfea9 100644 --- a/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/TestProtoSchemas.java +++ b/sdks/java/extensions/protobuf/src/test/java/org/apache/beam/sdk/extensions/protobuf/TestProtoSchemas.java @@ -17,10 +17,9 @@ */ package org.apache.beam.sdk.extensions.protobuf; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.SCHEMA_OPTION_META_NUMBER; +import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.SCHEMA_OPTION_META_TYPE_NAME; import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.getFieldNumber; -import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withFieldNumber; -import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withMapValueMessageName; -import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withMessageName; import com.google.protobuf.BoolValue; import com.google.protobuf.ByteString; @@ -68,42 +67,74 @@ import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; class TestProtoSchemas { + + static Field withFieldNumber(String name, FieldType fieldType, int fieldNumber) { + return Field.of(name, fieldType) + .withOptions( + Schema.Options.builder() + .setOption(SCHEMA_OPTION_META_NUMBER, FieldType.INT32, fieldNumber)); + } + + static Schema.Options withTypeName(String typeName) { + return Schema.Options.builder() + .setOption(SCHEMA_OPTION_META_TYPE_NAME, FieldType.STRING, typeName) + .build(); + } + // The schema we expect from the Primitive proto. static final Schema PRIMITIVE_SCHEMA = Schema.builder() - .addField("primitive_double", withFieldNumber(FieldType.DOUBLE, 1)) - .addField("primitive_float", withFieldNumber(FieldType.FLOAT, 2)) - .addField("primitive_int32", withFieldNumber(FieldType.INT32, 3)) - .addField("primitive_int64", withFieldNumber(FieldType.INT64, 4)) - .addField("primitive_uint32", withFieldNumber(FieldType.logicalType(new UInt32()), 5)) - .addField("primitive_uint64", withFieldNumber(FieldType.logicalType(new UInt64()), 6)) - .addField("primitive_sint32", withFieldNumber(FieldType.logicalType(new SInt32()), 7)) - .addField("primitive_sint64", withFieldNumber(FieldType.logicalType(new SInt64()), 8)) - .addField("primitive_fixed32", withFieldNumber(FieldType.logicalType(new Fixed32()), 9)) - .addField("primitive_fixed64", withFieldNumber(FieldType.logicalType(new Fixed64()), 10)) + .addField(withFieldNumber("primitive_double", FieldType.DOUBLE, 1)) + .addField(withFieldNumber("primitive_float", FieldType.FLOAT, 2)) + .addField(withFieldNumber("primitive_int32", FieldType.INT32, 3)) + .addField(withFieldNumber("primitive_int64", FieldType.INT64, 4)) + .addField(withFieldNumber("primitive_uint32", FieldType.logicalType(new UInt32()), 5)) + .addField(withFieldNumber("primitive_uint64", FieldType.logicalType(new UInt64()), 6)) + .addField(withFieldNumber("primitive_sint32", FieldType.logicalType(new SInt32()), 7)) + .addField(withFieldNumber("primitive_sint64", FieldType.logicalType(new SInt64()), 8)) + .addField(withFieldNumber("primitive_fixed32", FieldType.logicalType(new Fixed32()), 9)) + .addField(withFieldNumber("primitive_fixed64", FieldType.logicalType(new Fixed64()), 10)) .addField( - "primitive_sfixed32", withFieldNumber(FieldType.logicalType(new SFixed32()), 11)) + withFieldNumber("primitive_sfixed32", FieldType.logicalType(new SFixed32()), 11)) .addField( - "primitive_sfixed64", withFieldNumber(FieldType.logicalType(new SFixed64()), 12)) - .addField("primitive_bool", withFieldNumber(FieldType.BOOLEAN, 13)) - .addField("primitive_string", withFieldNumber(FieldType.STRING, 14)) - .addField("primitive_bytes", withFieldNumber(FieldType.BYTES, 15)) + withFieldNumber("primitive_sfixed64", FieldType.logicalType(new SFixed64()), 12)) + .addField(withFieldNumber("primitive_bool", FieldType.BOOLEAN, 13)) + .addField(withFieldNumber("primitive_string", FieldType.STRING, 14)) + .addField(withFieldNumber("primitive_bytes", FieldType.BYTES, 15)) + .setOptions( + Schema.Options.builder() + .setOption( + SCHEMA_OPTION_META_TYPE_NAME, + FieldType.STRING, + "proto3_schema_messages.Primitive")) .build(); static final Schema OPTIONAL_PRIMITIVE_SCHEMA = Schema.builder() - .addField("primitive_int32", withFieldNumber(FieldType.INT32, 1)) - .addField("primitive_bool", withFieldNumber(FieldType.BOOLEAN, 2)) - .addField("primitive_string", withFieldNumber(FieldType.STRING, 3)) - .addField("primitive_bytes", withFieldNumber(FieldType.BYTES, 4)) + .addField(withFieldNumber("primitive_int32", FieldType.INT32, 1)) + .addField(withFieldNumber("primitive_bool", FieldType.BOOLEAN, 2)) + .addField(withFieldNumber("primitive_string", FieldType.STRING, 3)) + .addField(withFieldNumber("primitive_bytes", FieldType.BYTES, 4)) + .setOptions( + Schema.Options.builder() + .setOption( + SCHEMA_OPTION_META_TYPE_NAME, + FieldType.STRING, + "proto2_schema_messages.OptionalPrimitive")) .build(); static final Schema REQUIRED_PRIMITIVE_SCHEMA = Schema.builder() - .addField("primitive_int32", withFieldNumber(FieldType.INT32, 1)) - .addField("primitive_bool", withFieldNumber(FieldType.BOOLEAN, 2)) - .addField("primitive_string", withFieldNumber(FieldType.STRING, 3)) - .addField("primitive_bytes", withFieldNumber(FieldType.BYTES, 4)) + .addField(withFieldNumber("primitive_int32", FieldType.INT32, 1)) + .addField(withFieldNumber("primitive_bool", FieldType.BOOLEAN, 2)) + .addField(withFieldNumber("primitive_string", FieldType.STRING, 3)) + .addField(withFieldNumber("primitive_bytes", FieldType.BYTES, 4)) + .setOptions( + Schema.Options.builder() + .setOption( + SCHEMA_OPTION_META_TYPE_NAME, + FieldType.STRING, + "proto2_schema_messages.RequiredPrimitive")) .build(); // A sample instance of the row. @@ -163,37 +194,38 @@ class TestProtoSchemas { // The schema for the RepeatedPrimitive proto. static final Schema REPEATED_SCHEMA = Schema.builder() - .addField("repeated_double", withFieldNumber(FieldType.array(FieldType.DOUBLE), 1)) - .addField("repeated_float", withFieldNumber(FieldType.array(FieldType.FLOAT), 2)) - .addField("repeated_int32", withFieldNumber(FieldType.array(FieldType.INT32), 3)) - .addField("repeated_int64", withFieldNumber(FieldType.array(FieldType.INT64), 4)) + .addField(withFieldNumber("repeated_double", FieldType.array(FieldType.DOUBLE), 1)) + .addField(withFieldNumber("repeated_float", FieldType.array(FieldType.FLOAT), 2)) + .addField(withFieldNumber("repeated_int32", FieldType.array(FieldType.INT32), 3)) + .addField(withFieldNumber("repeated_int64", FieldType.array(FieldType.INT64), 4)) .addField( - "repeated_uint32", - withFieldNumber(FieldType.array(FieldType.logicalType(new UInt32())), 5)) + withFieldNumber( + "repeated_uint32", FieldType.array(FieldType.logicalType(new UInt32())), 5)) .addField( - "repeated_uint64", - withFieldNumber(FieldType.array(FieldType.logicalType(new UInt64())), 6)) + withFieldNumber( + "repeated_uint64", FieldType.array(FieldType.logicalType(new UInt64())), 6)) .addField( - "repeated_sint32", - withFieldNumber(FieldType.array(FieldType.logicalType(new SInt32())), 7)) + withFieldNumber( + "repeated_sint32", FieldType.array(FieldType.logicalType(new SInt32())), 7)) .addField( - "repeated_sint64", - withFieldNumber(FieldType.array(FieldType.logicalType(new SInt64())), 8)) + withFieldNumber( + "repeated_sint64", FieldType.array(FieldType.logicalType(new SInt64())), 8)) .addField( - "repeated_fixed32", - withFieldNumber(FieldType.array(FieldType.logicalType(new Fixed32())), 9)) + withFieldNumber( + "repeated_fixed32", FieldType.array(FieldType.logicalType(new Fixed32())), 9)) .addField( - "repeated_fixed64", - withFieldNumber(FieldType.array(FieldType.logicalType(new Fixed64())), 10)) + withFieldNumber( + "repeated_fixed64", FieldType.array(FieldType.logicalType(new Fixed64())), 10)) .addField( - "repeated_sfixed32", - withFieldNumber(FieldType.array(FieldType.logicalType(new SFixed32())), 11)) + withFieldNumber( + "repeated_sfixed32", FieldType.array(FieldType.logicalType(new SFixed32())), 11)) .addField( - "repeated_sfixed64", - withFieldNumber(FieldType.array(FieldType.logicalType(new SFixed64())), 12)) - .addField("repeated_bool", withFieldNumber(FieldType.array(FieldType.BOOLEAN), 13)) - .addField("repeated_string", withFieldNumber(FieldType.array(FieldType.STRING), 14)) - .addField("repeated_bytes", withFieldNumber(FieldType.array(FieldType.BYTES), 15)) + withFieldNumber( + "repeated_sfixed64", FieldType.array(FieldType.logicalType(new SFixed64())), 12)) + .addField(withFieldNumber("repeated_bool", FieldType.array(FieldType.BOOLEAN), 13)) + .addField(withFieldNumber("repeated_string", FieldType.array(FieldType.STRING), 14)) + .addField(withFieldNumber("repeated_bytes", FieldType.array(FieldType.BYTES), 15)) + .setOptions(withTypeName("proto3_schema_messages.RepeatPrimitive")) .build(); // A sample instance of the row. @@ -263,17 +295,18 @@ class TestProtoSchemas { static final Schema MAP_PRIMITIVE_SCHEMA = Schema.builder() .addField( - "string_string_map", - withFieldNumber(FieldType.map(FieldType.STRING, FieldType.STRING), 1)) + withFieldNumber( + "string_string_map", FieldType.map(FieldType.STRING, FieldType.STRING), 1)) .addField( - "string_int_map", - withFieldNumber(FieldType.map(FieldType.STRING, FieldType.INT32), 2)) + withFieldNumber( + "string_int_map", FieldType.map(FieldType.STRING, FieldType.INT32), 2)) .addField( - "int_string_map", - withFieldNumber(FieldType.map(FieldType.INT32, FieldType.STRING), 3)) + withFieldNumber( + "int_string_map", FieldType.map(FieldType.INT32, FieldType.STRING), 3)) .addField( - "string_bytes_map", - withFieldNumber(FieldType.map(FieldType.STRING, FieldType.BYTES), 4)) + withFieldNumber( + "string_bytes_map", FieldType.map(FieldType.STRING, FieldType.BYTES), 4)) + .setOptions(withTypeName("proto3_schema_messages.MapPrimitive")) .build(); // A sample instance of the row. @@ -312,21 +345,15 @@ class TestProtoSchemas { static final Schema NESTED_SCHEMA = Schema.builder() .addField( - "nested", - withMessageName( - withFieldNumber(FieldType.row(PRIMITIVE_SCHEMA).withNullable(true), 1), - "proto3_schema_messages.Primitive")) + withFieldNumber("nested", FieldType.row(PRIMITIVE_SCHEMA).withNullable(true), 1)) .addField( - "nested_list", - withMessageName( - withFieldNumber(FieldType.array(FieldType.row(PRIMITIVE_SCHEMA)), 2), - "proto3_schema_messages.Primitive")) + withFieldNumber("nested_list", FieldType.array(FieldType.row(PRIMITIVE_SCHEMA)), 2)) .addField( - "nested_map", - withMapValueMessageName( - withFieldNumber( - FieldType.map(FieldType.STRING, FieldType.row(PRIMITIVE_SCHEMA)), 3), - "proto3_schema_messages.Primitive")) + withFieldNumber( + "nested_map", + FieldType.map(FieldType.STRING, FieldType.row(PRIMITIVE_SCHEMA)), + 3)) + .setOptions(withTypeName("proto3_schema_messages.Nested")) .build(); // A sample instance of the row. @@ -348,23 +375,19 @@ class TestProtoSchemas { // The schema for the OneOf proto. private static final List ONEOF_FIELDS = ImmutableList.of( - Field.of("oneof_int32", withFieldNumber(FieldType.INT32, 2)), - Field.of("oneof_bool", withFieldNumber(FieldType.BOOLEAN, 3)), - Field.of("oneof_string", withFieldNumber(FieldType.STRING, 4)), - Field.of( - "oneof_primitive", - withMessageName( - withFieldNumber(FieldType.row(PRIMITIVE_SCHEMA), 5), - "proto3_schema_messages.Primitive"))); + withFieldNumber("oneof_int32", FieldType.INT32, 2), + withFieldNumber("oneof_bool", FieldType.BOOLEAN, 3), + withFieldNumber("oneof_string", FieldType.STRING, 4), + withFieldNumber("oneof_primitive", FieldType.row(PRIMITIVE_SCHEMA), 5)); private static final Map ONE_OF_ENUM_MAP = - ONEOF_FIELDS.stream() - .collect(Collectors.toMap(Field::getName, f -> getFieldNumber(f.getType()))); + ONEOF_FIELDS.stream().collect(Collectors.toMap(Field::getName, f -> getFieldNumber(f))); static final OneOfType ONE_OF_TYPE = OneOfType.create(ONEOF_FIELDS, ONE_OF_ENUM_MAP); static final Schema ONEOF_SCHEMA = Schema.builder() .addField("special_oneof", FieldType.logicalType(ONE_OF_TYPE)) - .addField("place1", withFieldNumber(FieldType.STRING, 1)) - .addField("place2", withFieldNumber(FieldType.INT32, 6)) + .addField(withFieldNumber("place1", FieldType.STRING, 1)) + .addField(withFieldNumber("place2", FieldType.INT32, 6)) + .setOptions(withTypeName("proto3_schema_messages.OneOf")) .build(); // Sample row instances for each OneOf case. @@ -398,18 +421,17 @@ class TestProtoSchemas { // The schema for the OuterOneOf proto. private static final List OUTER_ONEOF_FIELDS = ImmutableList.of( - Field.of( - "oneof_oneof", - withMessageName( - withFieldNumber(FieldType.row(ONEOF_SCHEMA), 1), "proto3_schema_messages.OneOf")), - Field.of("oneof_int32", withFieldNumber(FieldType.INT32, 2))); + withFieldNumber("oneof_oneof", FieldType.row(ONEOF_SCHEMA), 1), + withFieldNumber("oneof_int32", FieldType.INT32, 2)); private static final Map OUTER_ONE_OF_ENUM_MAP = - OUTER_ONEOF_FIELDS.stream() - .collect(Collectors.toMap(Field::getName, f -> getFieldNumber(f.getType()))); + OUTER_ONEOF_FIELDS.stream().collect(Collectors.toMap(Field::getName, f -> getFieldNumber(f))); static final OneOfType OUTER_ONEOF_TYPE = OneOfType.create(OUTER_ONEOF_FIELDS, OUTER_ONE_OF_ENUM_MAP); static final Schema OUTER_ONEOF_SCHEMA = - Schema.builder().addField("outer_oneof", FieldType.logicalType(OUTER_ONEOF_TYPE)).build(); + Schema.builder() + .addField("outer_oneof", FieldType.logicalType(OUTER_ONEOF_TYPE)) + .setOptions(withTypeName("proto3_schema_messages.OuterOneOf")) + .build(); // A sample instance of the Row. static final Row OUTER_ONEOF_ROW = @@ -423,47 +445,24 @@ class TestProtoSchemas { static final Schema WKT_MESSAGE_SCHEMA = Schema.builder() - .addNullableField( - "double", - withMessageName(withFieldNumber(FieldType.DOUBLE, 1), "google.protobuf.DoubleValue")) - .addNullableField( - "float", - withMessageName(withFieldNumber(FieldType.FLOAT, 2), "google.protobuf.FloatValue")) - .addNullableField( - "int32", - withMessageName(withFieldNumber(FieldType.INT32, 3), "google.protobuf.Int32Value")) - .addNullableField( - "int64", - withMessageName(withFieldNumber(FieldType.INT64, 4), "google.protobuf.Int64Value")) - .addNullableField( - "uint32", - withMessageName( - withFieldNumber(FieldType.logicalType(new UInt32()), 5), - "google.protobuf.UInt32Value")) - .addNullableField( - "uint64", - withMessageName( - withFieldNumber(FieldType.logicalType(new UInt64()), 6), - "google.protobuf.UInt64Value")) - .addNullableField( - "bool", - withMessageName(withFieldNumber(FieldType.BOOLEAN, 13), "google.protobuf.BoolValue")) - .addNullableField( - "string", - withMessageName(withFieldNumber(FieldType.STRING, 14), "google.protobuf.StringValue")) - .addNullableField( - "bytes", - withMessageName(withFieldNumber(FieldType.BYTES, 15), "google.protobuf.BytesValue")) - .addNullableField( - "timestamp", - withMessageName( - withFieldNumber(FieldType.logicalType(new NanosInstant()), 16), - "google.protobuf.Timestamp")) - .addNullableField( - "duration", - withMessageName( - withFieldNumber(FieldType.logicalType(new NanosDuration()), 17), - "google.protobuf.Duration")) + .addField(withFieldNumber("double", FieldType.DOUBLE, 1).withNullable(true)) + .addField(withFieldNumber("float", FieldType.FLOAT, 2).withNullable(true)) + .addField(withFieldNumber("int32", FieldType.INT32, 3).withNullable(true)) + .addField(withFieldNumber("int64", FieldType.INT64, 4).withNullable(true)) + .addField( + withFieldNumber("uint32", FieldType.logicalType(new UInt32()), 5).withNullable(true)) + .addField( + withFieldNumber("uint64", FieldType.logicalType(new UInt64()), 6).withNullable(true)) + .addField(withFieldNumber("bool", FieldType.BOOLEAN, 13).withNullable(true)) + .addField(withFieldNumber("string", FieldType.STRING, 14).withNullable(true)) + .addField(withFieldNumber("bytes", FieldType.BYTES, 15).withNullable(true)) + .addField( + withFieldNumber("timestamp", FieldType.logicalType(new NanosInstant()), 16) + .withNullable(true)) + .addField( + withFieldNumber("duration", FieldType.logicalType(new NanosDuration()), 17) + .withNullable(true)) + .setOptions(withTypeName("proto3_schema_messages.WktMessage")) .build(); // A sample instance of the row. static final Instant JAVA_NOW = Instant.now(); @@ -505,10 +504,9 @@ class TestProtoSchemas { static final Schema OPTIONAL_NESTED_SCHEMA = Schema.builder() .addField( - "nested", - withMessageName( - withFieldNumber(FieldType.row(OPTIONAL_PRIMITIVE_SCHEMA), 1).withNullable(true), - "proto2_schema_messages.OptionalPrimitive")) + withFieldNumber("nested", FieldType.row(OPTIONAL_PRIMITIVE_SCHEMA), 1) + .withNullable(true)) + .setOptions(withTypeName("proto2_schema_messages.OptionalNested")) .build(); // A sample instance of the proto. @@ -519,10 +517,9 @@ class TestProtoSchemas { static final Schema REQUIRED_NESTED_SCHEMA = Schema.builder() .addField( - "nested", - withMessageName( - withFieldNumber(FieldType.row(REQUIRED_PRIMITIVE_SCHEMA), 1).withNullable(false), - "proto2_schema_messages.RequiredPrimitive")) + withFieldNumber("nested", FieldType.row(REQUIRED_PRIMITIVE_SCHEMA), 1) + .withNullable(false)) + .setOptions(withTypeName("proto2_schema_messages.RequiredNested")) .build(); // A sample instance of the proto. diff --git a/sdks/java/extensions/protobuf/src/test/proto/proto3_schema_messages.proto b/sdks/java/extensions/protobuf/src/test/proto/proto3_schema_messages.proto index 6f6ec4442a30..a19ea043eeb2 100644 --- a/sdks/java/extensions/protobuf/src/test/proto/proto3_schema_messages.proto +++ b/sdks/java/extensions/protobuf/src/test/proto/proto3_schema_messages.proto @@ -29,6 +29,8 @@ import "google/protobuf/timestamp.proto"; import "google/protobuf/wrappers.proto"; import "google/protobuf/descriptor.proto"; +import "proto3_schema_options.proto"; + option java_package = "org.apache.beam.sdk.extensions.protobuf"; message Primitive { @@ -121,3 +123,47 @@ message WktMessage { google.protobuf.Duration duration = 17; } +message OptionMessage { + option (proto3_schema_options.message_option_int) = 42; + option (proto3_schema_options.message_option_string) = "this is a message string"; + option (proto3_schema_options.message_option_message) = { + single_string: "foobar in message" + single_int32: 12 + single_int64: 34 + }; + option (proto3_schema_options.message_option_repeated) = "string_1"; + option (proto3_schema_options.message_option_repeated) = "string_2"; + option (proto3_schema_options.message_option_repeated) = "string_3"; + option (proto3_schema_options.message_option_repeated) = "string_3"; + option (proto3_schema_options.message_option_repeated_message) = { + single_string: "string in message in option in message" + }; + option (proto3_schema_options.message_option_repeated_message) = { + single_int32: 1 + }; + option (proto3_schema_options.message_option_repeated_message) = { + single_int64: 2 + }; + + string field_one = 1 [ + (proto3_schema_options.field_option_int) = 13, + (proto3_schema_options.field_option_string) = "this is a field string", + (proto3_schema_options.field_option_message) = { + single_string: "foobar in field" + single_int32: 56 + single_int64: 78 + }, + (proto3_schema_options.field_option_repeated) = "field_string_1", + (proto3_schema_options.field_option_repeated) = "field_string_2", + (proto3_schema_options.field_option_repeated) = "field_string_3", + + (proto3_schema_options.field_option_repeated_message) = { + single_string: "string in message in option in field" + }, + (proto3_schema_options.field_option_repeated_message) = { + single_int32: 77 + }, + (proto3_schema_options.field_option_repeated_message) = { + single_int64: 88 + }]; +} \ No newline at end of file diff --git a/sdks/java/extensions/protobuf/src/test/proto/proto3_schema_options.proto b/sdks/java/extensions/protobuf/src/test/proto/proto3_schema_options.proto new file mode 100644 index 000000000000..453ae1167640 --- /dev/null +++ b/sdks/java/extensions/protobuf/src/test/proto/proto3_schema_options.proto @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Protocol Buffer options used for testing Proto3 Schema implementation. + */ + +syntax = "proto3"; + +package proto3_schema_options; + +import "google/protobuf/descriptor.proto"; + +option java_package = "org.apache.beam.sdk.extensions.protobuf"; + +extend google.protobuf.MessageOptions { + OptionTestMessage message_option_message = 66661700; + int32 message_option_int = 66661708; + string message_option_string = 66661709; + repeated string message_option_repeated = 66661710; + repeated OptionTestMessage message_option_repeated_message = 66661711; +} + +extend google.protobuf.FieldOptions { + OptionTestMessage field_option_message = 66662700; + int32 field_option_int = 66662708; + string field_option_string = 66662709; + repeated string field_option_repeated = 66662710; + repeated OptionTestMessage field_option_repeated_message = 66662711; +} + +extend google.protobuf.EnumOptions { + OptionTestMessage enum_option_message = 66665700; + int32 enum_option_int = 66665708; + string enum_option_string = 66665709; + repeated string enum_option_repeated = 66665710; + repeated OptionTestMessage enum_option_repeated_message = 66665711; +} + +extend google.protobuf.EnumValueOptions { + OptionTestMessage enum_value_option_message = 66666700; + int32 enum_value_option_int = 66666708; + string enum_value_option_string = 66666709; + repeated string enum_value_option_repeated = 66666710; + repeated OptionTestMessage enum_value_option_repeated_message = 66666711; +} + +extend google.protobuf.OneofOptions { + OptionTestMessage oneof_option_message = 66667700; + int32 oneof_option_int = 66667708; + string oneof_option_string = 66667709; + repeated string oneof_option_repeated = 66667710; + repeated OptionTestMessage oneof_option_repeated_message = 66667711; +} + +message OptionTestMessage { + message OptionTestSubMessage { + string sub_message_name = 1; + } + + string single_string = 1; + repeated string repeated_string = 2; + + int32 single_int32 = 3; + repeated int32 repeated_int32 = 4; + + int64 single_int64 = 5; + + bytes single_bytes = 6; + repeated bytes repeated_bytes = 7; + + enum OptionEnum { + ENUM1 = 0; + ENUM2 = 1; + } + OptionEnum single_enum = 8; + OptionTestSubMessage single_message = 9; +} \ No newline at end of file