Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,7 @@ private OneOfType(List<Field> fields) {

private OneOfType(List<Field> fields, @Nullable Map<String, Integer> enumMap) {
List<Field> nullableFields =
fields.stream()
.map(f -> Field.nullable(f.getName(), f.getType()))
.collect(Collectors.toList());
fields.stream().map(f -> f.withNullable(true)).collect(Collectors.toList());
if (enumMap != null) {
nullableFields.stream().forEach(f -> checkArgument(enumMap.containsKey(f.getName())));
enumerationType = EnumerationType.create(enumMap);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -922,7 +922,7 @@ private static <ProtoT> FieldValueGetter createGetter(
fieldValueTypeSupplier.get(clazz, oneOfType.getOneOfSchema()).stream()
.collect(Collectors.toMap(FieldValueTypeInformation::getName, f -> f));
for (Field oneOfField : oneOfType.getOneOfSchema().getFields()) {
int protoFieldIndex = getFieldNumber(oneOfField.getType());
int protoFieldIndex = getFieldNumber(oneOfField);
FieldValueGetter oneOfFieldGetter =
createGetter(
oneOfFieldTypes.get(oneOfField.getName()),
Expand Down Expand Up @@ -993,7 +993,7 @@ FieldValueSetter<ProtoBuilderT, Object> getProtoFieldValueSetter(
TreeMap<Integer, FieldValueSetter<ProtoBuilderT, Object>> oneOfSetters = Maps.newTreeMap();
for (Field oneOfField : oneOfType.getOneOfSchema().getFields()) {
FieldValueSetter setter = getProtoFieldValueSetter(oneOfField, methods, builderClass);
oneOfSetters.put(getFieldNumber(oneOfField.getType()), setter);
oneOfSetters.put(getFieldNumber(oneOfField), setter);
}
return createOneOfSetter(field.getName(), oneOfSetters, builderClass);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,20 @@
*/
package org.apache.beam.sdk.extensions.protobuf;

import com.google.protobuf.Any;
import com.google.protobuf.Api;
import com.google.protobuf.DescriptorProtos;
import com.google.protobuf.Descriptors;
import com.google.protobuf.Duration;
import com.google.protobuf.DynamicMessage;
import com.google.protobuf.Empty;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.FieldMask;
import com.google.protobuf.Int32Value;
import com.google.protobuf.SourceContext;
import com.google.protobuf.Struct;
import com.google.protobuf.Timestamp;
import com.google.protobuf.Type;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
Expand Down Expand Up @@ -77,19 +89,53 @@ private static Map<String, DescriptorProtos.FileDescriptorProto> extractProtoMap
private static Descriptors.FileDescriptor convertToFileDescriptorMap(
String name,
Map<String, DescriptorProtos.FileDescriptorProto> inMap,
Map<String, Descriptors.FileDescriptor> outMap) {
Map<String, Descriptors.FileDescriptor> outMap,
ExtensionRegistry registry) {
if (outMap.containsKey(name)) {
return outMap.get(name);
}
DescriptorProtos.FileDescriptorProto fileDescriptorProto = inMap.get(name);
if (fileDescriptorProto == null) {
if ("google/protobuf/descriptor.proto".equals(name)) {
outMap.put(
"google/protobuf/descriptor.proto",
DescriptorProtos.FieldOptions.getDescriptor().getFile());
return DescriptorProtos.FieldOptions.getDescriptor().getFile();
Descriptors.FileDescriptor fd;
switch (name) {
case "google/protobuf/descriptor.proto":
fd = DescriptorProtos.FieldOptions.getDescriptor().getFile();
break;
case "google/protobuf/wrappers.proto":
fd = Int32Value.getDescriptor().getFile();
break;
case "google/protobuf/timestamp.proto":
fd = Timestamp.getDescriptor().getFile();
break;
case "google/protobuf/duration.proto":
fd = Duration.getDescriptor().getFile();
break;
case "google/protobuf/any.proto":
fd = Any.getDescriptor().getFile();
break;
case "google/protobuf/api.proto":
fd = Api.getDescriptor().getFile();
break;
case "google/protobuf/empty.proto":
fd = Empty.getDescriptor().getFile();
break;
case "google/protobuf/field_mask.proto":
fd = FieldMask.getDescriptor().getFile();
break;
case "google/protobuf/source_context.proto":
fd = SourceContext.getDescriptor().getFile();
break;
case "google/protobuf/struct.proto":
fd = Struct.getDescriptor().getFile();
break;
case "google/protobuf/type.proto":
fd = Type.getDescriptor().getFile();
break;
default:
return null;
}
return null;
outMap.put(name, fd);
return fd;
} else {
List<Descriptors.FileDescriptor> dependencies = new ArrayList<>();
if (fileDescriptorProto.getDependencyCount() > 0) {
Expand All @@ -98,7 +144,7 @@ private static Descriptors.FileDescriptor convertToFileDescriptorMap(
.forEach(
dependencyName -> {
Descriptors.FileDescriptor fileDescriptor =
convertToFileDescriptorMap(dependencyName, inMap, outMap);
convertToFileDescriptorMap(dependencyName, inMap, outMap, registry);
if (fileDescriptor != null) {
dependencies.add(fileDescriptor);
}
Expand All @@ -108,6 +154,18 @@ private static Descriptors.FileDescriptor convertToFileDescriptorMap(
Descriptors.FileDescriptor fileDescriptor =
Descriptors.FileDescriptor.buildFrom(
fileDescriptorProto, dependencies.toArray(new Descriptors.FileDescriptor[0]));
fileDescriptor
.getExtensions()
.forEach(
extension -> {
if (extension.getType() == Descriptors.FieldDescriptor.Type.MESSAGE) {
registry.add(
extension, DynamicMessage.newBuilder(extension.getMessageType()).build());
} else {
registry.add(extension);
}
});
Descriptors.FileDescriptor.internalUpdateFileDescriptor(fileDescriptor, registry);
outMap.put(name, fileDescriptor);
return fileDescriptor;
} catch (Descriptors.DescriptorValidationException e) {
Expand Down Expand Up @@ -147,10 +205,14 @@ public static ProtoDomain buildFrom(InputStream inputStream) throws IOException

private void crosswire() {
HashMap<String, DescriptorProtos.FileDescriptorProto> map = new HashMap<>();
fileDescriptorSet.getFileList().forEach(fdp -> map.put(fdp.getName(), fdp));
fileDescriptorSet.getFileList().stream()
.filter(fdp -> !fdp.getName().startsWith("google/protobuf"))
.forEach(fdp -> map.put(fdp.getName(), fdp));

ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance();
Map<String, Descriptors.FileDescriptor> outMap = new HashMap<>();
map.forEach((fileName, proto) -> convertToFileDescriptorMap(fileName, map, outMap));
map.forEach(
(fileName, proto) -> convertToFileDescriptorMap(fileName, map, outMap, extensionRegistry));
fileDescriptorMap = outMap;

indexOptionsByNumber(fileDescriptorMap.values());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@
*/
package org.apache.beam.sdk.extensions.protobuf;

import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.SCHEMA_OPTION_META_NUMBER;
import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.SCHEMA_OPTION_META_TYPE_NAME;
import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.getFieldNumber;
import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.getMapKeyMessageName;
import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.getMapValueMessageName;
import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.getMessageName;
import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withFieldNumber;
import static org.apache.beam.sdk.extensions.protobuf.ProtoSchemaTranslator.withMessageName;

import com.google.protobuf.ByteString;
import com.google.protobuf.Descriptors;
Expand Down Expand Up @@ -110,28 +108,34 @@ private Object readResolve() {

Convert createConverter(Schema.Field field) {
Schema.FieldType fieldType = field.getType();
String messageName = getMessageName(fieldType);
if (messageName != null && messageName.length() > 0) {
if (fieldType.getNullable()) {
Schema.Field valueField =
Schema.Field.of("value", withFieldNumber(Schema.FieldType.BOOLEAN, 1));
switch (messageName) {
case "google.protobuf.StringValue":
case "google.protobuf.DoubleValue":
case "google.protobuf.FloatValue":
case "google.protobuf.BoolValue":
case "google.protobuf.Int64Value":
case "google.protobuf.Int32Value":
case "google.protobuf.UInt64Value":
case "google.protobuf.UInt32Value":
withFieldNumber(Schema.Field.of("value", Schema.FieldType.BOOLEAN), 1);
switch (fieldType.getTypeName()) {
case BYTE:
case INT16:
case INT32:
case INT64:
case FLOAT:
case DOUBLE:
case STRING:
case BOOLEAN:
return new WrapperConvert(field, new PrimitiveConvert(valueField));
case "google.protobuf.BytesValue":
case BYTES:
return new WrapperConvert(field, new BytesConvert(valueField));
case "google.protobuf.Timestamp":
case "google.protobuf.Duration":
// handled by logical type case
break;
case LOGICAL_TYPE:
String identifier = field.getType().getLogicalType().getIdentifier();
switch (identifier) {
case ProtoSchemaLogicalTypes.UInt32.IDENTIFIER:
case ProtoSchemaLogicalTypes.UInt64.IDENTIFIER:
return new WrapperConvert(field, new PrimitiveConvert(valueField));
default:
}
// fall through
default:
}
}

switch (fieldType.getTypeName()) {
case BYTE:
case INT16:
Expand Down Expand Up @@ -260,7 +264,8 @@ public DynamicMessage.Builder invokeNewBuilder() {

@Override
public Context getSubContext(Schema.Field field) {
String messageName = getMessageName(field.getType());
String messageName =
field.getType().getRowSchema().getOptions().getValue(SCHEMA_OPTION_META_TYPE_NAME);
return new DescriptorContext(messageName, domain);
}
}
Expand All @@ -274,9 +279,10 @@ abstract static class Convert<ValueT, InT> {
private int number;

Convert(Schema.Field field) {
try {
this.number = getFieldNumber(field.getType());
} catch (NumberFormatException e) {
Schema.Options options = field.getOptions();
if (options.hasOption(SCHEMA_OPTION_META_NUMBER)) {
this.number = options.getValue(SCHEMA_OPTION_META_NUMBER);
} else {
this.number = -1;
}
}
Expand Down Expand Up @@ -546,16 +552,8 @@ static class MapConvert extends Convert<Map, Map> {
MapConvert(ProtoDynamicMessageSchema protoSchema, Schema.Field field) {
super(field);
Schema.FieldType fieldType = field.getType();
key =
protoSchema.createConverter(
Schema.Field.of(
"KEY",
withMessageName(fieldType.getMapKeyType(), getMapKeyMessageName(fieldType))));
value =
protoSchema.createConverter(
Schema.Field.of(
"VALUE",
withMessageName(fieldType.getMapValueType(), getMapValueMessageName(fieldType))));
key = protoSchema.createConverter(Schema.Field.of("KEY", fieldType.getMapKeyType()));
value = protoSchema.createConverter(Schema.Field.of("VALUE", fieldType.getMapValueType()));
}

@Override
Expand Down Expand Up @@ -617,11 +615,7 @@ static class ArrayConvert extends Convert<List, List> {
ArrayConvert(ProtoDynamicMessageSchema protoSchema, Schema.Field field) {
super(field);
Schema.FieldType collectionElementType = field.getType().getCollectionElementType();
this.element =
protoSchema.createConverter(
Schema.Field.of(
"ELEMENT",
withMessageName(collectionElementType, getMessageName(field.getType()))));
this.element = protoSchema.createConverter(Schema.Field.of("ELEMENT", collectionElementType));
}

@Override
Expand Down Expand Up @@ -703,9 +697,11 @@ static class OneOfConvert extends Convert<OneOfType.Value, OneOfType.Value> {
super(field);
this.logicalType = (OneOfType) logicalType;
for (Schema.Field oneOfField : this.logicalType.getOneOfSchema().getFields()) {
int fieldNumber = getFieldNumber(oneOfField.getType());
int fieldNumber = getFieldNumber(oneOfField);
oneOfConvert.put(
fieldNumber, new NullableConvert(oneOfField, protoSchema.createConverter(oneOfField)));
fieldNumber,
new NullableConvert(
oneOfField, protoSchema.createConverter(oneOfField.withNullable(false))));
}
}

Expand Down
Loading