Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import java.time.LocalTime;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
Expand Down Expand Up @@ -221,11 +220,18 @@ private static TableFieldSchema fieldDescriptorFromBeamField(Field field) {
case ITERABLE:
@Nullable FieldType elementType = field.getType().getCollectionElementType();
if (elementType == null) {
throw new RuntimeException("Unexpected null element type!");
throw new RuntimeException("Unexpected null element type on " + field.getName());
}
TypeName containedTypeName =
Preconditions.checkNotNull(
elementType.getTypeName(),
"Null type name found in contained type at " + field.getName());
Preconditions.checkState(
!Preconditions.checkNotNull(elementType.getTypeName()).isCollectionType(),
"Nested arrays not supported by BigQuery.");
!(containedTypeName.isCollectionType() || containedTypeName.isMapType()),
"Nested container types are not supported by BigQuery. Field "
+ field.getName()
+ " contains a type "
+ containedTypeName.name());
TableFieldSchema elementFieldSchema =
fieldDescriptorFromBeamField(Field.of(field.getName(), elementType));
builder = builder.setType(elementFieldSchema.getType());
Expand All @@ -244,7 +250,24 @@ private static TableFieldSchema fieldDescriptorFromBeamField(Field field) {
builder = builder.setType(type);
break;
case MAP:
throw new RuntimeException("Map types not supported by BigQuery.");
@Nullable FieldType keyType = field.getType().getMapKeyType();
@Nullable FieldType valueType = field.getType().getMapValueType();
if (keyType == null) {
throw new RuntimeException(
"Unexpected null element type for the map's key on " + field.getName());
}
if (valueType == null) {
throw new RuntimeException(
"Unexpected null element type for the map's value on " + field.getName());
}

builder =
builder
.setType(TableFieldSchema.Type.STRUCT)
.addFields(fieldDescriptorFromBeamField(Field.of("key", keyType)))
.addFields(fieldDescriptorFromBeamField(Field.of("value", valueType)))
.setMode(TableFieldSchema.Mode.REPEATED);
break;
default:
@Nullable
TableFieldSchema.Type primitiveType = PRIMITIVE_TYPES.get(field.getType().getTypeName());
Expand Down Expand Up @@ -289,25 +312,34 @@ private static Object toProtoValue(
case ROW:
return messageFromBeamRow(fieldDescriptor.getMessageType(), (Row) value, null, -1);
case ARRAY:
List<Object> list = (List<Object>) value;
@Nullable FieldType arrayElementType = beamFieldType.getCollectionElementType();
if (arrayElementType == null) {
throw new RuntimeException("Unexpected null element type!");
}
return list.stream()
.map(v -> toProtoValue(fieldDescriptor, arrayElementType, v))
.collect(Collectors.toList());
case ITERABLE:
Iterable<Object> iterable = (Iterable<Object>) value;
@Nullable FieldType iterableElementType = beamFieldType.getCollectionElementType();
if (iterableElementType == null) {
throw new RuntimeException("Unexpected null element type!");
throw new RuntimeException("Unexpected null element type: " + fieldDescriptor.getName());
}

return StreamSupport.stream(iterable.spliterator(), false)
.map(v -> toProtoValue(fieldDescriptor, iterableElementType, v))
.collect(Collectors.toList());
case MAP:
throw new RuntimeException("Map types not supported by BigQuery.");
Map<Object, Object> map = (Map<Object, Object>) value;
@Nullable FieldType keyType = beamFieldType.getMapKeyType();
@Nullable FieldType valueType = beamFieldType.getMapValueType();
if (keyType == null) {
throw new RuntimeException("Unexpected null for key type: " + fieldDescriptor.getName());
}
if (valueType == null) {
throw new RuntimeException(
"Unexpected null for value type: " + fieldDescriptor.getName());
}

return map.entrySet().stream()
.map(
(Map.Entry<Object, Object> entry) ->
mapEntryToProtoValue(
fieldDescriptor.getMessageType(), keyType, valueType, entry))
.collect(Collectors.toList());
default:
return scalarToProtoValue(beamFieldType, value);
}
Expand Down Expand Up @@ -337,6 +369,28 @@ static Object scalarToProtoValue(FieldType beamFieldType, Object value) {
}
}

/**
 * Converts one Beam map entry into the proto message that represents it.
 *
 * <p>The message is built against {@code descriptor}, which must declare fields named {@code
 * "key"} and {@code "value"}. The entry's key and value are each converted with {@code
 * toProtoValue}; a component whose converted value is {@code null} is left unset on the message.
 *
 * @param descriptor descriptor of the message holding the entry; must contain "key" and "value"
 * @param keyFieldType Beam type of the map's keys
 * @param valueFieldType Beam type of the map's values
 * @param entryValue the map entry to convert
 * @return the built {@link DynamicMessage} for this entry
 */
static Object mapEntryToProtoValue(
    Descriptor descriptor,
    FieldType keyFieldType,
    FieldType valueFieldType,
    Map.Entry<Object, Object> entryValue) {
  DynamicMessage.Builder builder = DynamicMessage.newBuilder(descriptor);

  // Constant messages: this runs once per map entry, so avoid building strings on the hot path.
  FieldDescriptor keyFieldDescriptor =
      Preconditions.checkNotNull(
          descriptor.findFieldByName("key"),
          "Map entry descriptor is missing the \"key\" field");
  @Nullable Object key = toProtoValue(keyFieldDescriptor, keyFieldType, entryValue.getKey());
  if (key != null) {
    builder.setField(keyFieldDescriptor, key);
  }

  FieldDescriptor valueFieldDescriptor =
      Preconditions.checkNotNull(
          descriptor.findFieldByName("value"),
          "Map entry descriptor is missing the \"value\" field");
  @Nullable
  Object value = toProtoValue(valueFieldDescriptor, valueFieldType, entryValue.getValue());
  if (value != null) {
    builder.setField(valueFieldDescriptor, value);
  }
  return builder.build();
}

static ByteString serializeBigDecimalToNumeric(BigDecimal o) {
return serializeBigDecimal(o, NUMERIC_SCALE, MAX_NUMERIC_VALUE, MIN_NUMERIC_VALUE, "Numeric");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;

import com.google.protobuf.ByteString;
import com.google.protobuf.DescriptorProtos.DescriptorProto;
Expand All @@ -36,8 +37,11 @@
import java.time.LocalTime;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.Field;
import org.apache.beam.sdk.schemas.Schema.FieldType;
Expand Down Expand Up @@ -284,12 +288,14 @@ public class BeamRowToStorageApiProtoTest {
.addField("nested", FieldType.row(BASE_SCHEMA).withNullable(true))
.addField("nestedArray", FieldType.array(FieldType.row(BASE_SCHEMA)))
.addField("nestedIterable", FieldType.iterable(FieldType.row(BASE_SCHEMA)))
.addField("nestedMap", FieldType.map(FieldType.STRING, FieldType.row(BASE_SCHEMA)))
.build();
private static final Row NESTED_ROW =
Row.withSchema(NESTED_SCHEMA)
.withFieldValue("nested", BASE_ROW)
.withFieldValue("nestedArray", ImmutableList.of(BASE_ROW, BASE_ROW))
.withFieldValue("nestedIterable", ImmutableList.of(BASE_ROW, BASE_ROW))
.withFieldValue("nestedMap", ImmutableMap.of("key1", BASE_ROW, "key2", BASE_ROW))
.build();

@Test
Expand Down Expand Up @@ -347,12 +353,12 @@ public void testNestedFromSchema() {
.collect(
Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getLabel));

assertEquals(3, types.size());
assertEquals(4, types.size());

Map<String, DescriptorProto> nestedTypes =
descriptor.getNestedTypeList().stream()
.collect(Collectors.toMap(DescriptorProto::getName, Functions.identity()));
assertEquals(3, nestedTypes.size());
assertEquals(4, nestedTypes.size());
assertEquals(Type.TYPE_MESSAGE, types.get("nested"));
assertEquals(Label.LABEL_OPTIONAL, typeLabels.get("nested"));
String nestedTypeName1 = typeNames.get("nested");
Expand All @@ -379,6 +385,87 @@ public void testNestedFromSchema() {
.collect(
Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType));
assertEquals(expectedBaseTypes, nestedTypes3);

assertEquals(Type.TYPE_MESSAGE, types.get("nestedmap"));
assertEquals(Label.LABEL_REPEATED, typeLabels.get("nestedmap"));
String nestedTypeName4 = typeNames.get("nestedmap");
// expects 2 fields in the nested map, key and value
assertEquals(2, nestedTypes.get(nestedTypeName4).getFieldList().size());
Supplier<Stream<FieldDescriptorProto>> stream =
() -> nestedTypes.get(nestedTypeName4).getFieldList().stream();
assertTrue(stream.get().anyMatch(fdp -> fdp.getName().equals("key")));
assertTrue(stream.get().anyMatch(fdp -> fdp.getName().equals("value")));

Map<String, Type> nestedTypes4 =
nestedTypes.get(nestedTypeName4).getNestedTypeList().stream()
.flatMap(vdesc -> vdesc.getFieldList().stream())
.collect(
Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType));
assertEquals(expectedBaseTypes, nestedTypes4);
}

/**
 * Verifies the proto descriptor generated for a map whose values are arrays and for a map
 * declared nullable: each must appear as a repeated message with exactly one "key" field and one
 * "value" field.
 */
@Test
public void testParticularMapsFromSchemas() {
  Schema nestedMapSchemaVariations =
      Schema.builder()
          .addField(
              "nestedMultiMap",
              FieldType.map(FieldType.STRING, FieldType.array(FieldType.STRING)))
          .addField(
              "nestedMapNullable",
              FieldType.map(FieldType.STRING, FieldType.DOUBLE).withNullable(true))
          .build();

  DescriptorProto descriptor =
      TableRowToStorageApiProto.descriptorSchemaFromTableSchema(
          BeamRowToStorageApiProto.protoTableSchemaFromBeamSchema(nestedMapSchemaVariations),
          true,
          false);

  Map<String, Type> types =
      descriptor.getFieldList().stream()
          .collect(
              Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getType));
  Map<String, String> typeNames =
      descriptor.getFieldList().stream()
          .collect(
              Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getTypeName));
  Map<String, Label> typeLabels =
      descriptor.getFieldList().stream()
          .collect(
              Collectors.toMap(FieldDescriptorProto::getName, FieldDescriptorProto::getLabel));

  Map<String, DescriptorProto> nestedTypes =
      descriptor.getNestedTypeList().stream()
          .collect(Collectors.toMap(DescriptorProto::getName, Functions.identity()));
  assertEquals(2, nestedTypes.size());

  assertEquals(Type.TYPE_MESSAGE, types.get("nestedmultimap"));
  assertEquals(Label.LABEL_REPEATED, typeLabels.get("nestedmultimap"));
  String nestedMultiMapName = typeNames.get("nestedmultimap");
  // expects 2 fields for the map with array values: key and value
  assertEquals(2, nestedTypes.get(nestedMultiMapName).getFieldList().size());
  // A Supplier is used because a Stream can only be consumed once.
  Supplier<Stream<FieldDescriptorProto>> stream =
      () -> nestedTypes.get(nestedMultiMapName).getFieldList().stream();
  assertEquals(1L, stream.get().filter(fdp -> fdp.getName().equals("key")).count());
  assertEquals(1L, stream.get().filter(fdp -> fdp.getName().equals("value")).count());
  // the array-typed value must itself be a repeated field
  assertEquals(
      1L,
      stream
          .get()
          .filter(fdp -> fdp.getName().equals("value"))
          .filter(fdp -> fdp.getLabel().equals(Label.LABEL_REPEATED))
          .count());

  assertEquals(Type.TYPE_MESSAGE, types.get("nestedmapnullable"));
  // even though the field is marked as nullable in the row, we should see repeated in the proto
  assertEquals(Label.LABEL_REPEATED, typeLabels.get("nestedmapnullable"));
  String nestedMapNullableName = typeNames.get("nestedmapnullable");
  // expects 2 fields in the nullable map: key and value
  assertEquals(2, nestedTypes.get(nestedMapNullableName).getFieldList().size());
  stream = () -> nestedTypes.get(nestedMapNullableName).getFieldList().stream();
  assertEquals(1L, stream.get().filter(fdp -> fdp.getName().equals("key")).count());
  assertEquals(1L, stream.get().filter(fdp -> fdp.getName().equals("value")).count());
}

private void assertBaseRecord(DynamicMessage msg) {
Expand All @@ -395,7 +482,7 @@ public void testMessageFromTableRow() throws Exception {
BeamRowToStorageApiProto.protoTableSchemaFromBeamSchema(NESTED_SCHEMA), true, false);
DynamicMessage msg =
BeamRowToStorageApiProto.messageFromBeamRow(descriptor, NESTED_ROW, null, -1);
assertEquals(3, msg.getAllFields().size());
assertEquals(4, msg.getAllFields().size());

Map<String, FieldDescriptor> fieldDescriptors =
descriptor.getFields().stream()
Expand All @@ -404,6 +491,63 @@ public void testMessageFromTableRow() throws Exception {
assertBaseRecord(nestedMsg);
}

/**
 * Verifies that rows containing maps (with scalar and iterable values) and null array/map fields
 * are converted into proto messages: map entries become repeated key/value messages and null
 * collections produce no repeated elements.
 */
@Test
public void testMessageFromTableRowForArraysAndMaps() throws Exception {
  Schema nestedMapSchemaVariations =
      Schema.builder()
          .addField("nestedArrayNullable", FieldType.array(FieldType.STRING).withNullable(true))
          .addField("nestedMap", FieldType.map(FieldType.STRING, FieldType.STRING))
          .addField(
              "nestedMultiMap",
              FieldType.map(FieldType.STRING, FieldType.iterable(FieldType.STRING)))
          .addField(
              "nestedMapNullable",
              FieldType.map(FieldType.STRING, FieldType.DOUBLE).withNullable(true))
          .build();

  Row nestedRow =
      Row.withSchema(nestedMapSchemaVariations)
          .withFieldValue("nestedArrayNullable", null)
          .withFieldValue("nestedMap", ImmutableMap.of("key1", "value1"))
          .withFieldValue(
              "nestedMultiMap",
              ImmutableMap.of("multikey1", ImmutableList.of("multivalue1", "multivalue2")))
          .withFieldValue("nestedMapNullable", null)
          .build();

  Descriptor descriptor =
      TableRowToStorageApiProto.getDescriptorFromTableSchema(
          BeamRowToStorageApiProto.protoTableSchemaFromBeamSchema(nestedMapSchemaVariations),
          true,
          false);
  DynamicMessage msg =
      BeamRowToStorageApiProto.messageFromBeamRow(descriptor, nestedRow, null, -1);

  Map<String, FieldDescriptor> fieldDescriptors =
      descriptor.getFields().stream()
          .collect(Collectors.toMap(FieldDescriptor::getName, Functions.identity()));

  // Plain map: the single entry must carry both its key and its value.
  FieldDescriptor nestedMapField = fieldDescriptors.get("nestedmap");
  DynamicMessage nestedMapEntryMsg = (DynamicMessage) msg.getRepeatedField(nestedMapField, 0);
  assertEquals(
      "key1",
      nestedMapEntryMsg.getField(nestedMapField.getMessageType().findFieldByName("key")));
  String value =
      (String)
          nestedMapEntryMsg.getField(nestedMapField.getMessageType().findFieldByName("value"));
  assertEquals("value1", value);

  // Map with iterable values: the entry's value is a repeated field that preserves order.
  FieldDescriptor nestedMultiMapField = fieldDescriptors.get("nestedmultimap");
  DynamicMessage nestedMultiMapEntryMsg =
      (DynamicMessage) msg.getRepeatedField(nestedMultiMapField, 0);
  assertEquals(
      "multikey1",
      nestedMultiMapEntryMsg.getField(
          nestedMultiMapField.getMessageType().findFieldByName("key")));
  List<String> values =
      (List<String>)
          nestedMultiMapEntryMsg.getField(
              nestedMultiMapField.getMessageType().findFieldByName("value"));
  assertEquals(2, values.size());
  assertEquals("multivalue1", values.get(0));
  assertEquals("multivalue2", values.get(1));

  // Null array and null map fields must produce no repeated elements.
  assertEquals(0, msg.getRepeatedFieldCount(fieldDescriptors.get("nestedarraynullable")));
  assertEquals(0, msg.getRepeatedFieldCount(fieldDescriptors.get("nestedmapnullable")));
}

@Test
public void testCdcFields() throws Exception {
Descriptor descriptor =
Expand All @@ -413,7 +557,7 @@ public void testCdcFields() throws Exception {
assertNotNull(descriptor.findFieldByName(StorageApiCDC.CHANGE_SQN_COLUMN));
DynamicMessage msg =
BeamRowToStorageApiProto.messageFromBeamRow(descriptor, NESTED_ROW, "UPDATE", 42);
assertEquals(5, msg.getAllFields().size());
assertEquals(6, msg.getAllFields().size());

Map<String, FieldDescriptor> fieldDescriptors =
descriptor.getFields().stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,18 @@ public void testToTableSchema_map() {
assertThat(field.getFields(), containsInAnyOrder(MAP_KEY, MAP_VALUE));
}

/**
 * Converting {@code MAP_ARRAY_TYPE} must yield exactly one field: a REPEATED STRUCT named "map"
 * whose sub-fields are {@code MAP_KEY} and {@code MAP_VALUE}.
 */
@Test
public void testToTableSchema_map_array() {
  TableSchema converted = toTableSchema(MAP_ARRAY_TYPE);

  // exactly one top-level field is expected
  assertThat(converted.getFields().size(), equalTo(1));

  TableFieldSchema mapField = converted.getFields().get(0);
  assertThat(mapField.getName(), equalTo("map"));
  assertThat(mapField.getType(), equalTo(StandardSQLTypeName.STRUCT.toString()));
  assertThat(mapField.getMode(), equalTo(Mode.REPEATED.toString()));
  // sub-field order is not asserted
  assertThat(mapField.getFields(), containsInAnyOrder(MAP_KEY, MAP_VALUE));
}

@Test
public void testToTableRow_flat() {
TableRow row = toTableRow().apply(FLAT_ROW);
Expand Down