Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,18 @@ static Schema getSchema(Class<? extends Message> clazz) {
}

static Schema getSchema(Descriptors.Descriptor descriptor) {
Set<Integer> oneOfFields = Sets.newHashSet();
/* OneOfComponentFields refers to the field number in the protobuf where the component subfields
* are. This is needed to prevent double inclusion of the component fields.*/
Set<Integer> oneOfComponentFields = Sets.newHashSet();
/* OneOfFieldLocation stores the field number of the first field in the OneOf. Using this, we can use the location
of the first field in the OneOf as the location of the entire OneOf.*/
Map<Integer, Field> oneOfFieldLocation = Maps.newHashMap();
List<Field> fields = Lists.newArrayListWithCapacity(descriptor.getFields().size());
for (OneofDescriptor oneofDescriptor : descriptor.getOneofs()) {
List<Field> subFields = Lists.newArrayListWithCapacity(oneofDescriptor.getFieldCount());
Map<String, Integer> enumIds = Maps.newHashMap();
for (FieldDescriptor fieldDescriptor : oneofDescriptor.getFields()) {
oneOfFields.add(fieldDescriptor.getNumber());
oneOfComponentFields.add(fieldDescriptor.getNumber());
// Store proto field number in a field option.
FieldType fieldType = beamFieldTypeFromProtoField(fieldDescriptor);
subFields.add(
Expand All @@ -172,17 +177,26 @@ static Schema getSchema(Descriptors.Descriptor descriptor) {
enumIds.putIfAbsent(fieldDescriptor.getName(), fieldDescriptor.getNumber()) == null);
}
FieldType oneOfType = FieldType.logicalType(OneOfType.create(subFields, enumIds));
fields.add(Field.of(oneofDescriptor.getName(), oneOfType));
oneOfFieldLocation.put(
oneofDescriptor.getFields().get(0).getNumber(),
Field.of(oneofDescriptor.getName(), oneOfType));
}

for (Descriptors.FieldDescriptor fieldDescriptor : descriptor.getFields()) {
if (!oneOfFields.contains(fieldDescriptor.getNumber())) {
int fieldDescriptorNumber = fieldDescriptor.getNumber();
if (!oneOfComponentFields.contains(fieldDescriptorNumber)) {
// Store proto field number in metadata.
FieldType fieldType = beamFieldTypeFromProtoField(fieldDescriptor);
fields.add(
withFieldNumber(
Field.of(fieldDescriptor.getName(), fieldType), fieldDescriptor.getNumber())
withFieldNumber(Field.of(fieldDescriptor.getName(), fieldType), fieldDescriptorNumber)
.withOptions(getFieldOptions(fieldDescriptor)));
/* Note that descriptor.getFields() returns an iterator in the order of the fields in the .proto file, not
* in field number order. Therefore we can safely insert the OneOfField at the field of its first component.*/
} else {
Field oneOfField = oneOfFieldLocation.get(fieldDescriptorNumber);
if (oneOfField != null) {
fields.add(oneOfField);
}
}
}
return Schema.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NESTED_PROTO;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NESTED_ROW;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NESTED_SCHEMA;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NONCONTIGUOUS_ONEOF_PROTO;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NONCONTIGUOUS_ONEOF_ROW;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NONCONTIGUOUS_ONEOF_SCHEMA;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NULL_MAP_PRIMITIVE_PROTO;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NULL_MAP_PRIMITIVE_ROW;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NULL_REPEATED_PROTO;
Expand All @@ -45,6 +48,15 @@
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REPEATED_PROTO;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REPEATED_ROW;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REPEATED_SCHEMA;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REVERSED_ONEOF_PROTO_BOOL;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REVERSED_ONEOF_PROTO_INT32;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REVERSED_ONEOF_PROTO_PRIMITIVE;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REVERSED_ONEOF_PROTO_STRING;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REVERSED_ONEOF_ROW_BOOL;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REVERSED_ONEOF_ROW_INT32;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REVERSED_ONEOF_ROW_PRIMITIVE;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REVERSED_ONEOF_ROW_STRING;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REVERSED_ONEOF_SCHEMA;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.WKT_MESSAGE_PROTO;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.WKT_MESSAGE_ROW;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.WKT_MESSAGE_SCHEMA;
Expand All @@ -61,10 +73,12 @@
import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.EnumMessage;
import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.MapPrimitive;
import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.Nested;
import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.NonContiguousOneOf;
import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.OneOf;
import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.OuterOneOf;
import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.Primitive;
import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.RepeatPrimitive;
import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.ReversedOneOf;
import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.WktMessage;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType;
Expand Down Expand Up @@ -256,6 +270,78 @@ public void testOneOfRowToProto() {
assertEquals(ONEOF_PROTO_PRIMITIVE.toString(), fromRow.apply(ONEOF_ROW_PRIMITIVE).toString());
}

@Test
public void testReversedOneOfSchema() {
ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(ReversedOneOf.getDescriptor());
Schema schema = schemaProvider.getSchema();
assertEquals(REVERSED_ONEOF_SCHEMA, schema);
}

@Test
public void testReversedOneOfProtoToRow() throws InvalidProtocolBufferException {
ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(ReversedOneOf.getDescriptor());
SerializableFunction<DynamicMessage, Row> toRow = schemaProvider.getToRowFunction();
// equality doesn't work between dynamic messages and other,
// so we compare string representation
assertEquals(
REVERSED_ONEOF_ROW_INT32.toString(),
toRow.apply(toDynamic(REVERSED_ONEOF_PROTO_INT32)).toString());
assertEquals(
REVERSED_ONEOF_ROW_BOOL.toString(),
toRow.apply(toDynamic(REVERSED_ONEOF_PROTO_BOOL)).toString());
assertEquals(
REVERSED_ONEOF_ROW_STRING.toString(),
toRow.apply(toDynamic(REVERSED_ONEOF_PROTO_STRING)).toString());
assertEquals(
REVERSED_ONEOF_ROW_PRIMITIVE.toString(),
toRow.apply(toDynamic(REVERSED_ONEOF_PROTO_PRIMITIVE)).toString());
}

@Test
public void testReversedOneOfRowToProto() {
ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(ReversedOneOf.getDescriptor());
SerializableFunction<Row, DynamicMessage> fromRow = schemaProvider.getFromRowFunction();
assertEquals(
REVERSED_ONEOF_PROTO_INT32.toString(), fromRow.apply(REVERSED_ONEOF_ROW_INT32).toString());
assertEquals(
REVERSED_ONEOF_PROTO_BOOL.toString(), fromRow.apply(REVERSED_ONEOF_ROW_BOOL).toString());
assertEquals(
REVERSED_ONEOF_PROTO_STRING.toString(),
fromRow.apply(REVERSED_ONEOF_ROW_STRING).toString());
assertEquals(
REVERSED_ONEOF_PROTO_PRIMITIVE.toString(),
fromRow.apply(REVERSED_ONEOF_ROW_PRIMITIVE).toString());
}

@Test
public void testNonContiguousOneOfSchema() {
ProtoDynamicMessageSchema schemaProvider =
schemaFromDescriptor(NonContiguousOneOf.getDescriptor());
Schema schema = schemaProvider.getSchema();
assertEquals(NONCONTIGUOUS_ONEOF_SCHEMA, schema);
}

@Test
public void testNonContiguousOneOfProtoToRow() throws InvalidProtocolBufferException {
ProtoDynamicMessageSchema schemaProvider =
schemaFromDescriptor(NonContiguousOneOf.getDescriptor());
SerializableFunction<DynamicMessage, Row> toRow = schemaProvider.getToRowFunction();
// equality doesn't work between dynamic messages and other,
// so we compare string representation
assertEquals(
NONCONTIGUOUS_ONEOF_ROW.toString(),
toRow.apply(toDynamic(NONCONTIGUOUS_ONEOF_PROTO)).toString());
}

@Test
public void testNonContiguousOneOfRowToProto() {
ProtoDynamicMessageSchema schemaProvider =
schemaFromDescriptor(NonContiguousOneOf.getDescriptor());
SerializableFunction<Row, DynamicMessage> fromRow = schemaProvider.getFromRowFunction();
assertEquals(
NONCONTIGUOUS_ONEOF_PROTO.toString(), fromRow.apply(NONCONTIGUOUS_ONEOF_ROW).toString());
}

@Test
public void testOuterOneOfSchema() {
ProtoDynamicMessageSchema schemaProvider = schemaFromDescriptor(OuterOneOf.getDescriptor());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NESTED_PROTO;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NESTED_ROW;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NESTED_SCHEMA;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NONCONTIGUOUS_ONEOF_PROTO;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NONCONTIGUOUS_ONEOF_ROW;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NULL_MAP_PRIMITIVE_PROTO;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NULL_MAP_PRIMITIVE_ROW;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.NULL_REPEATED_PROTO;
Expand Down Expand Up @@ -51,6 +53,14 @@
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REQUIRED_PRIMITIVE_PROTO;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REQUIRED_PRIMITIVE_ROW;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REQUIRED_PRIMITIVE_SCHEMA;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REVERSED_ONEOF_PROTO_BOOL;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REVERSED_ONEOF_PROTO_INT32;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REVERSED_ONEOF_PROTO_PRIMITIVE;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REVERSED_ONEOF_PROTO_STRING;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REVERSED_ONEOF_ROW_BOOL;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REVERSED_ONEOF_ROW_INT32;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REVERSED_ONEOF_ROW_PRIMITIVE;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.REVERSED_ONEOF_ROW_STRING;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.WKT_MESSAGE_PROTO;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.WKT_MESSAGE_ROW;
import static org.apache.beam.sdk.extensions.protobuf.TestProtoSchemas.WKT_MESSAGE_SCHEMA;
Expand All @@ -64,10 +74,12 @@
import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.EnumMessage;
import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.MapPrimitive;
import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.Nested;
import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.NonContiguousOneOf;
import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.OneOf;
import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.OuterOneOf;
import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.Primitive;
import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.RepeatPrimitive;
import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.ReversedOneOf;
import org.apache.beam.sdk.extensions.protobuf.Proto3SchemaMessages.WktMessage;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.FieldType;
Expand Down Expand Up @@ -279,6 +291,40 @@ public void testOuterOneOfRowToProto() {
assertEquals(OUTER_ONEOF_PROTO, fromRow.apply(OUTER_ONEOF_ROW));
}

@Test
public void testReversedOneOfProtoToRow() {
SerializableFunction<ReversedOneOf, Row> toRow =
new ProtoMessageSchema().toRowFunction(TypeDescriptor.of(ReversedOneOf.class));
assertEquals(REVERSED_ONEOF_ROW_INT32, toRow.apply(REVERSED_ONEOF_PROTO_INT32));
assertEquals(REVERSED_ONEOF_ROW_BOOL, toRow.apply(REVERSED_ONEOF_PROTO_BOOL));
assertEquals(REVERSED_ONEOF_ROW_STRING, toRow.apply(REVERSED_ONEOF_PROTO_STRING));
assertEquals(REVERSED_ONEOF_ROW_PRIMITIVE, toRow.apply(REVERSED_ONEOF_PROTO_PRIMITIVE));
}

@Test
public void testReversedOneOfRowToProto() {
SerializableFunction<Row, ReversedOneOf> fromRow =
new ProtoMessageSchema().fromRowFunction(TypeDescriptor.of(ReversedOneOf.class));
assertEquals(REVERSED_ONEOF_PROTO_INT32, fromRow.apply(REVERSED_ONEOF_ROW_INT32));
assertEquals(REVERSED_ONEOF_PROTO_BOOL, fromRow.apply(REVERSED_ONEOF_ROW_BOOL));
assertEquals(REVERSED_ONEOF_PROTO_STRING, fromRow.apply(REVERSED_ONEOF_ROW_STRING));
assertEquals(REVERSED_ONEOF_PROTO_PRIMITIVE, fromRow.apply(REVERSED_ONEOF_ROW_PRIMITIVE));
}

@Test
public void testNonContiguousOneOfProtoToRow() {
SerializableFunction<NonContiguousOneOf, Row> toRow =
new ProtoMessageSchema().toRowFunction(TypeDescriptor.of(NonContiguousOneOf.class));
assertEquals(NONCONTIGUOUS_ONEOF_ROW, toRow.apply(NONCONTIGUOUS_ONEOF_PROTO));
}

@Test
public void testNonContiguousOneOfRowToProto() {
SerializableFunction<Row, NonContiguousOneOf> fromRow =
new ProtoMessageSchema().fromRowFunction(TypeDescriptor.of(NonContiguousOneOf.class));
assertEquals(NONCONTIGUOUS_ONEOF_PROTO, fromRow.apply(NONCONTIGUOUS_ONEOF_ROW));
}

private static final EnumerationType ENUM_TYPE =
EnumerationType.create(ImmutableMap.of("ZERO", 0, "TWO", 2, "THREE", 3));
private static final Schema ENUM_SCHEMA =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,20 @@ public void testOneOfSchema() {
ProtoSchemaTranslator.getSchema(Proto3SchemaMessages.OneOf.class));
}

@Test
public void testReversedOneOfSchema() {
assertEquals(
TestProtoSchemas.REVERSED_ONEOF_SCHEMA,
ProtoSchemaTranslator.getSchema(Proto3SchemaMessages.ReversedOneOf.class));
}

@Test
public void testNonContiguousOneOfSchema() {
assertEquals(
TestProtoSchemas.NONCONTIGUOUS_ONEOF_SCHEMA,
ProtoSchemaTranslator.getSchema(Proto3SchemaMessages.NonContiguousOneOf.class));
}

@Test
public void testNestedOneOfSchema() {
assertEquals(
Expand Down
Loading