diff --git a/model/pipeline/src/main/proto/schema.proto b/model/pipeline/src/main/proto/schema.proto index 2cf404e3341f..dcf75ca4a5ee 100644 --- a/model/pipeline/src/main/proto/schema.proto +++ b/model/pipeline/src/main/proto/schema.proto @@ -31,6 +31,7 @@ option java_outer_classname = "SchemaApi"; message Schema { repeated Field fields = 1; string id = 2; + repeated Option options = 3; } message Field { @@ -39,6 +40,7 @@ message Field { FieldType type = 3; int32 id = 4; int32 encoding_position = 5; + repeated Option options = 6; } message FieldType { @@ -91,6 +93,12 @@ message LogicalType { FieldValue argument = 5; } +message Option { + string name = 1; + FieldType type = 2; + FieldValue value = 3; +} + message Row { repeated FieldValue values = 1; } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java index 81b7922fe495..d7abf4e8c93e 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java @@ -163,7 +163,7 @@ public RowCoder fromComponents(List> components, byte[] payload) { components.isEmpty(), "Expected empty component list, but received: " + components); Schema schema; try { - schema = SchemaTranslation.fromProto(SchemaApi.Schema.parseFrom(payload)); + schema = SchemaTranslation.schemaFromProto(SchemaApi.Schema.parseFrom(payload)); } catch (InvalidProtocolBufferException e) { throw new RuntimeException("Unable to parse schema for RowCoder: ", e); } diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CommonCoderTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CommonCoderTest.java index 4ec1a7d58a0c..2ac9db5ee364 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CommonCoderTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CommonCoderTest.java @@ -302,7 +302,8 @@ private static Object convertValue(Object value, CommonCoder coderSpec, Coder co } else if (s.equals(getUrn(StandardCoders.Enum.ROW))) { Schema schema; try { - schema = SchemaTranslation.fromProto(SchemaApi.Schema.parseFrom(coderSpec.getPayload())); + schema = + SchemaTranslation.schemaFromProto(SchemaApi.Schema.parseFrom(coderSpec.getPayload())); } catch (InvalidProtocolBufferException e) { throw new RuntimeException("Failed to parse schema payload for row coder", e); } diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SchemaTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SchemaTranslationTest.java index 6022d70999cc..bfb9adb4988b 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SchemaTranslationTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SchemaTranslationTest.java @@ -20,12 +20,17 @@ import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertThat; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import org.apache.beam.model.pipeline.v1.SchemaApi; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.SchemaTranslation; import org.apache.beam.sdk.schemas.logicaltypes.FixedBytes; +import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; import org.junit.Test; import org.junit.runner.RunWith; @@ -41,6 +46,37 @@ public class SchemaTranslationTest { public static class ToFromProtoTest { @Parameters(name = "{index}: {0}") public static Iterable data() { + Map optionMap = new HashMap<>(); + optionMap.put("string", 42); + List optionList = new ArrayList<>(); + optionList.add("string"); + Row optionRow = + Row.withSchema( + Schema.builder() + .addField("field_one", FieldType.STRING) + .addField("field_two", FieldType.INT32) + .build()) + .addValue("value") + .addValue(42) + .build(); + + Schema.Options.Builder optionsBuilder = + Schema.Options.builder() + .setOption("field_option_boolean", FieldType.BOOLEAN, true) + .setOption("field_option_byte", FieldType.BYTE, (byte) 12) + .setOption("field_option_int16", FieldType.INT16, (short) 12) + .setOption("field_option_int32", FieldType.INT32, 12) + .setOption("field_option_int64", FieldType.INT64, 12L) + .setOption("field_option_string", FieldType.STRING, "foo") + .setOption("field_option_bytes", FieldType.BYTES, new byte[] {0x42, 0x69, 0x00}) + .setOption("field_option_float", FieldType.FLOAT, (float) 12.0) + .setOption("field_option_double", FieldType.DOUBLE, 12.0) + .setOption( + "field_option_map", FieldType.map(FieldType.STRING, FieldType.INT32), optionMap) + .setOption("field_option_array", FieldType.array(FieldType.STRING), optionList) + .setRowOption("field_option_row", optionRow) + .setOption("field_option_value", FieldType.STRING, "other"); + return ImmutableList.builder() .add(Schema.of(Field.of("string", FieldType.STRING))) .add( @@ -77,6 +113,49 @@ public static Iterable data() { Schema.of( Field.of("decimal", FieldType.DECIMAL), Field.of("datetime", FieldType.DATETIME))) .add(Schema.of(Field.of("logical", FieldType.logicalType(FixedBytes.of(24))))) + .add( + Schema.of( + Field.of("field_with_option_atomic", FieldType.STRING) + .withOptions( + Schema.Options.builder() + .setOption( + "field_option_atomic", FieldType.INT32, Integer.valueOf(42)) + .build())) + .withOptions( + Schema.Options.builder() + .setOption("schema_option_atomic", FieldType.BOOLEAN, true))) + .add( + Schema.of( + Field.of("field_with_option_map", FieldType.STRING) + .withOptions( + Schema.Options.builder() + .setOption( + "field_option_map", + FieldType.map(FieldType.STRING, FieldType.INT32), + optionMap))) + .withOptions( + Schema.Options.builder() + .setOption( + "field_option_map", + FieldType.map(FieldType.STRING, FieldType.INT32), + optionMap))) + .add( + Schema.of( + Field.of("field_with_option_array", FieldType.STRING) + .withOptions( + Schema.Options.builder() + .setOption( + "field_option_array", + FieldType.array(FieldType.STRING), + optionList) + .build())) + .withOptions( + Schema.Options.builder() + .setOption( + "field_option_array", FieldType.array(FieldType.STRING), optionList))) + .add( + Schema.of(Field.of("field", FieldType.STRING).withOptions(optionsBuilder)) + .withOptions(optionsBuilder)) .build(); } @@ -87,7 +166,7 @@ public static Iterable data() { public void toAndFromProto() throws Exception { SchemaApi.Schema schemaProto = SchemaTranslation.schemaToProto(schema, true); - Schema decodedSchema = SchemaTranslation.fromProto(schemaProto); + Schema decodedSchema = SchemaTranslation.schemaFromProto(schemaProto); assertThat(decodedSchema, equalTo(schema)); } } diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/SchemaCoderCloudObjectTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/SchemaCoderCloudObjectTranslator.java index 8987abb37ca7..4ea07178228d 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/SchemaCoderCloudObjectTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/SchemaCoderCloudObjectTranslator.java @@ -91,7 +91,7 @@ public SchemaCoder fromCloudObject(CloudObject cloudObject) { SchemaApi.Schema protoSchema = SchemaApi.Schema.parseFrom( StringUtils.jsonStringToByteArray(Structs.getString(cloudObject, SCHEMA))); - Schema schema = SchemaTranslation.fromProto(protoSchema); + Schema schema = SchemaTranslation.schemaFromProto(protoSchema); return SchemaCoder.of(schema, typeDescriptor, toRowFunction, fromRowFunction); } catch (IOException e) { throw new RuntimeException(e); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java index 67eb832f13ba..5689ffc813f7 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.schemas; +import static org.apache.beam.sdk.values.SchemaVerification.verifyFieldValue; + import com.google.auto.value.AutoValue; import java.io.Serializable; import java.nio.charset.StandardCharsets; @@ -24,10 +26,12 @@ import java.util.Arrays; import java.util.Collections; import java.util.Comparator; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.TreeMap; import java.util.UUID; import java.util.stream.Collector; import java.util.stream.Collectors; @@ -36,6 +40,7 @@ import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.BiMap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.HashBiMap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; @@ -93,9 +98,12 @@ public String toString() { // equal, so we can short circuit comparison. @Nullable private UUID uuid = null; + private final Options options; + /** Builder class for building {@link Schema} objects. */ public static class Builder { List fields; + Options options = Options.none(); public Builder() { this.fields = Lists.newArrayList(); @@ -206,12 +214,23 @@ public Builder addMapField(String name, FieldType keyType, FieldType valueType) return this; } + /** Returns a copy of the Field with isNullable set. */ + public Builder setOptions(Options options) { + this.options = options; + return this; + } + + public Builder setOptions(Options.Builder optionsBuilder) { + this.options = optionsBuilder.build(); + return this; + } + public int getLastFieldId() { return fields.size() - 1; } public Schema build() { - return new Schema(fields); + return new Schema(fields, options); } } @@ -220,23 +239,37 @@ public static Builder builder() { } public Schema(List fields) { + this(fields, Options.none()); + } + + public Schema(List fields, Options options) { this.fields = fields; int index = 0; for (Field field : fields) { - if (fieldIndices.get(field.getName()) != null) { - throw new IllegalArgumentException( - "Duplicate field " + field.getName() + " added to schema"); - } + Preconditions.checkArgument( + fieldIndices.get(field.getName()) == null, + "Duplicate field " + field.getName() + " added to schema"); encodingPositions.put(field.getName(), index); fieldIndices.put(field.getName(), index++); } this.hashCode = Objects.hash(fieldIndices, fields); + this.options = options; } public static Schema of(Field... fields) { return Schema.builder().addFields(fields).build(); } + /** Returns a copy of the Schema with the options set. */ + public Schema withOptions(Options options) { + return new Schema(fields, getOptions().toBuilder().addOptions(options).build()); + } + + /** Returns a copy of the Schema with the options set. */ + public Schema withOptions(Options.Builder optionsBuilder) { + return withOptions(optionsBuilder.build()); + } + /** Set this schema's UUID. All schemas with the same UUID must be guaranteed to be identical. */ public void setUUID(UUID uuid) { this.uuid = uuid; @@ -273,7 +306,8 @@ public boolean equals(Object o) { return Objects.equals(uuid, other.uuid); } return Objects.equals(fieldIndices, other.fieldIndices) - && Objects.equals(getFields(), other.getFields()); + && Objects.equals(getFields(), other.getFields()) + && Objects.equals(getOptions(), other.getOptions()); } /** Returns true if two schemas are equal ignoring field names and descriptions. */ @@ -350,6 +384,8 @@ public String toString() { builder.append(field); builder.append(System.lineSeparator()); } + builder.append("Options:"); + builder.append(options); return builder.toString(); }; @@ -637,24 +673,24 @@ public static FieldType of(TypeName typeName) { public static final FieldType DATETIME = FieldType.of(TypeName.DATETIME); /** Create an array type for the given field type. */ - public static final FieldType array(FieldType elementType) { + public static FieldType array(FieldType elementType) { return FieldType.forTypeName(TypeName.ARRAY).setCollectionElementType(elementType).build(); } /** @deprecated Set the nullability on the elementType instead */ @Deprecated - public static final FieldType array(FieldType elementType, boolean nullable) { + public static FieldType array(FieldType elementType, boolean nullable) { return FieldType.forTypeName(TypeName.ARRAY) .setCollectionElementType(elementType.withNullable(nullable)) .build(); } - public static final FieldType iterable(FieldType elementType) { + public static FieldType iterable(FieldType elementType) { return FieldType.forTypeName(TypeName.ITERABLE).setCollectionElementType(elementType).build(); } /** Create a map type for the given key and value types. */ - public static final FieldType map(FieldType keyType, FieldType valueType) { + public static FieldType map(FieldType keyType, FieldType valueType) { return FieldType.forTypeName(TypeName.MAP) .setMapKeyType(keyType) .setMapValueType(valueType) @@ -663,8 +699,7 @@ public static final FieldType map(FieldType keyType, FieldType valueType) { /** @deprecated Set the nullability on the valueType instead */ @Deprecated - public static final FieldType map( - FieldType keyType, FieldType valueType, boolean valueTypeNullable) { + public static FieldType map(FieldType keyType, FieldType valueType, boolean valueTypeNullable) { return FieldType.forTypeName(TypeName.MAP) .setMapKeyType(keyType) .setMapValueType(valueType.withNullable(valueTypeNullable)) @@ -672,13 +707,12 @@ public static final FieldType map( } /** Create a map type for the given key and value types. */ - public static final FieldType row(Schema schema) { + public static FieldType row(Schema schema) { return FieldType.forTypeName(TypeName.ROW).setRowSchema(schema).build(); } /** Creates a logical type based on a primitive field type. */ - public static final FieldType logicalType( - LogicalType logicalType) { + public static FieldType logicalType(LogicalType logicalType) { return FieldType.forTypeName(TypeName.LOGICAL_TYPE).setLogicalType(logicalType).build(); } @@ -874,6 +908,9 @@ public abstract static class Field implements Serializable { /** Returns the fields {@link FieldType}. */ public abstract FieldType getType(); + /** Returns the fields {@link Options}. */ + public abstract Options getOptions(); + public abstract Builder toBuilder(); /** Builder for {@link Field}. */ @@ -885,6 +922,13 @@ public abstract static class Builder { public abstract Builder setType(FieldType fieldType); + public abstract Builder setOptions(Options options); + + public Builder setOptions(Options.Builder optionsBuilder) { + setOptions(optionsBuilder.build()); + return this; + } + public abstract Field build(); } @@ -894,6 +938,7 @@ public static Field of(String name, FieldType fieldType) { .setName(name) .setDescription("") .setType(fieldType) + .setOptions(Options.none()) .build(); } @@ -903,6 +948,7 @@ public static Field nullable(String name, FieldType fieldType) { .setName(name) .setDescription("") .setType(fieldType.withNullable(true)) + .setOptions(Options.none()) .build(); } @@ -926,6 +972,16 @@ public Field withNullable(boolean isNullable) { return toBuilder().setType(getType().withNullable(isNullable)).build(); } + /** Returns a copy of the Field with the options set. */ + public Field withOptions(Options options) { + return toBuilder().setOptions(getOptions().toBuilder().addOptions(options).build()).build(); + } + + /** Returns a copy of the Field with the options set. */ + public Field withOptions(Options.Builder optionsBuilder) { + return withOptions(optionsBuilder.build()); + } + @Override public boolean equals(Object o) { if (!(o instanceof Field)) { @@ -934,7 +990,8 @@ public boolean equals(Object o) { Field other = (Field) o; return Objects.equals(getName(), other.getName()) && Objects.equals(getDescription(), other.getDescription()) - && Objects.equals(getType(), other.getType()); + && Objects.equals(getType(), other.getType()) + && Objects.equals(getOptions(), other.getOptions()); } /** Returns true if two fields are equal, ignoring name and description. */ @@ -953,6 +1010,209 @@ public int hashCode() { } } + public static class Options implements Serializable { + private final Map options; + + @Override + public String toString() { + TreeMap sorted = new TreeMap(options); + return "{" + sorted + '}'; + } + + Map getAllOptions() { + return options; + } + + public Set getOptionNames() { + return options.keySet(); + } + + public boolean hasOptions() { + return options.size() > 0; + } + + public boolean hasOption(String name) { + return options.containsKey(name); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Options options1 = (Options) o; + if (!options.keySet().equals(options1.options.keySet())) { + return false; + } + for (Map.Entry optionEntry : options.entrySet()) { + Option thisOption = optionEntry.getValue(); + Option otherOption = options1.options.get(optionEntry.getKey()); + if (!thisOption.equals(otherOption)) { + return false; + } + } + return true; + } + + @Override + public int hashCode() { + return Objects.hash(options); + } + + static class Option implements Serializable { + Option(FieldType type, Object value) { + this.type = type; + this.value = value; + } + + private FieldType type; + private Object value; + + @SuppressWarnings("TypeParameterUnusedInFormals") + T getValue() { + return (T) value; + } + + FieldType getType() { + return type; + } + + @Override + public String toString() { + return "Option{type=" + type + ", value=" + value + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Option option = (Option) o; + return Objects.equals(type, option.type) + && Row.Equals.deepEquals(value, option.value, type); + } + + @Override + public int hashCode() { + return Row.Equals.deepHashCode(value, type); + } + } + + public static class Builder { + private Map options; + + Builder(Map init) { + this.options = new HashMap<>(init); + } + + Builder() { + this(new HashMap<>()); + } + + public Builder setRowOption(String optionName, Row value) { + setOption(optionName, FieldType.row(value.getSchema()), value); + return this; + } + + public Builder setOption(String optionName, FieldType fieldType, Object value) { + if (value == null) { + if (fieldType.getNullable()) { + options.put(optionName, new Option(fieldType, null)); + } else { + throw new IllegalArgumentException( + String.format("Option %s is not nullable", optionName)); + } + } else { + options.put( + optionName, new Option(fieldType, verifyFieldValue(value, fieldType, optionName))); + } + return this; + } + + public Options build() { + return new Options(this.options); + } + + public Builder addOptions(Options options) { + this.options.putAll(options.options); + return this; + } + } + + Options(Map options) { + this.options = options; + } + + Options() { + this.options = new HashMap<>(); + } + + Options.Builder toBuilder() { + return new Builder(new HashMap<>(this.options)); + } + + public static Options.Builder builder() { + return new Builder(); + } + + public static Options none() { + return new Options(); + } + + /** Get the value of an option. If the option is not found null is returned. */ + @SuppressWarnings("TypeParameterUnusedInFormals") + @Nullable + public T getValue(String optionName) { + Option option = options.get(optionName); + if (option != null) { + return option.getValue(); + } + throw new IllegalArgumentException( + String.format("No option found with name %s.", optionName)); + } + + /** Get the value of an option. If the option is not found null is returned. */ + @Nullable + public T getValue(String optionName, Class valueClass) { + return getValue(optionName); + } + + /** Get the value of an option. If the option is not found the default value is returned. */ + @Nullable + public T getValueOrDefault(String optionName, T defaultValue) { + Option option = options.get(optionName); + if (option != null) { + return option.getValue(); + } + return defaultValue; + } + + /** Get the type of an option. */ + @Nullable + public FieldType getType(String optionName) { + Option option = options.get(optionName); + if (option != null) { + return option.getType(); + } + throw new IllegalArgumentException( + String.format("No option found with name %s.", optionName)); + } + + public static Options.Builder setOption(String optionName, FieldType fieldType, Object value) { + return Options.builder().setOption(optionName, fieldType, value); + } + + public static Options.Builder setRowOption(String optionName, Row value) { + return Options.builder().setRowOption(optionName, value); + } + } + /** Collects a stream of {@link Field}s into a {@link Schema}. */ public static Collector, Schema> toSchema() { return Collector.of( @@ -986,10 +1246,8 @@ public Field getField(String name) { /** Find the index of a given field. */ public int indexOf(String fieldName) { Integer index = fieldIndices.get(fieldName); - if (index == null) { - throw new IllegalArgumentException( - String.format("Cannot find field %s in schema %s", fieldName, this)); - } + Preconditions.checkArgument( + index != null, String.format("Cannot find field %s in schema %s", fieldName, this)); return index; } @@ -1001,9 +1259,7 @@ public boolean hasField(String fieldName) { /** Return the name of field by index. */ public String nameOf(int fieldIndex) { String name = fieldIndices.inverse().get(fieldIndex); - if (name == null) { - throw new IllegalArgumentException(String.format("Cannot find field %d", fieldIndex)); - } + Preconditions.checkArgument(name != null, String.format("Cannot find field %d", fieldIndex)); return name; } @@ -1011,4 +1267,8 @@ public String nameOf(int fieldIndex) { public int getFieldCount() { return getFields().size(); } + + public Options getOptions() { + return this.options; + } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaTranslation.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaTranslation.java index 0d19b65d2af0..d311643b7e62 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaTranslation.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaTranslation.java @@ -17,10 +17,15 @@ */ package org.apache.beam.sdk.schemas; +import java.util.ArrayList; +import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.UUID; +import java.util.stream.Collectors; import org.apache.beam.model.pipeline.v1.SchemaApi; import org.apache.beam.model.pipeline.v1.SchemaApi.ArrayTypeValue; +import org.apache.beam.model.pipeline.v1.SchemaApi.AtomicTypeValue; import org.apache.beam.model.pipeline.v1.SchemaApi.FieldValue; import org.apache.beam.model.pipeline.v1.SchemaApi.IterableTypeValue; import org.apache.beam.model.pipeline.v1.SchemaApi.MapTypeEntry; @@ -57,6 +62,7 @@ public static SchemaApi.Schema schemaToProto(Schema schema, boolean serializeLog serializeLogicalType); builder.addFields(protoField); } + builder.addAllOptions(optionsToProto(schema.getOptions())); return builder.build(); } @@ -68,6 +74,7 @@ private static SchemaApi.Field fieldToProto( .setType(fieldTypeToProto(field.getType(), serializeLogicalType)) .setId(fieldId) .setEncodingPosition(position) + .addAllOptions(optionsToProto(field.getOptions())) .build(); } @@ -110,7 +117,7 @@ private static SchemaApi.FieldType fieldTypeToProto( .setArgumentType( fieldTypeToProto(logicalType.getArgumentType(), serializeLogicalType)) .setArgument( - rowFieldToProto(logicalType.getArgumentType(), logicalType.getArgument())) + fieldValueToProto(logicalType.getArgumentType(), logicalType.getArgument())) .setRepresentation( fieldTypeToProto(logicalType.getBaseType(), serializeLogicalType)) // TODO(BEAM-7855): "javasdk" types should only be a last resort. Types defined in @@ -172,7 +179,7 @@ private static SchemaApi.FieldType fieldTypeToProto( return builder.build(); } - public static Schema fromProto(SchemaApi.Schema protoSchema) { + public static Schema schemaFromProto(SchemaApi.Schema protoSchema) { Schema.Builder builder = Schema.builder(); Map encodingLocationMap = Maps.newHashMap(); for (SchemaApi.Field protoField : protoSchema.getFieldsList()) { @@ -180,17 +187,18 @@ public static Schema fromProto(SchemaApi.Schema protoSchema) { builder.addField(field); encodingLocationMap.put(protoField.getName(), protoField.getEncodingPosition()); } + builder.setOptions(optionsFromProto(protoSchema.getOptionsList())); Schema schema = builder.build(); schema.setEncodingPositions(encodingLocationMap); if (!protoSchema.getId().isEmpty()) { schema.setUUID(UUID.fromString(protoSchema.getId())); } - return schema; } private static Field fieldFromProto(SchemaApi.Field protoField) { return Field.of(protoField.getName(), fieldTypeFromProto(protoField.getType())) + .withOptions(optionsFromProto(protoField.getOptionsList())) .withDescription(protoField.getDescription()); } @@ -233,7 +241,7 @@ private static FieldType fieldTypeFromProtoWithoutNullable(SchemaApi.FieldType p "Encountered unknown AtomicType: " + protoFieldType.getAtomicType()); } case ROW_TYPE: - return FieldType.row(fromProto(protoFieldType.getRowType().getSchema())); + return FieldType.row(schemaFromProto(protoFieldType.getRowType().getSchema())); case ARRAY_TYPE: return FieldType.array(fieldTypeFromProto(protoFieldType.getArrayType().getElementType())); case ITERABLE_TYPE: @@ -268,12 +276,21 @@ private static FieldType fieldTypeFromProtoWithoutNullable(SchemaApi.FieldType p public static SchemaApi.Row rowToProto(Row row) { SchemaApi.Row.Builder builder = SchemaApi.Row.newBuilder(); for (int i = 0; i < row.getFieldCount(); ++i) { - builder.addValues(rowFieldToProto(row.getSchema().getField(i).getType(), row.getValue(i))); + builder.addValues(fieldValueToProto(row.getSchema().getField(i).getType(), row.getValue(i))); + } + return builder.build(); + } + + public static Object rowFromProto(SchemaApi.Row row, FieldType fieldType) { + Row.Builder builder = Row.withSchema(fieldType.getRowSchema()); + for (int i = 0; i < row.getValuesCount(); ++i) { + builder.addValue( + fieldValueFromProto(fieldType.getRowSchema().getField(i).getType(), row.getValues(i))); } return builder.build(); } - private static SchemaApi.FieldValue rowFieldToProto(FieldType fieldType, Object value) { + static SchemaApi.FieldValue fieldValueToProto(FieldType fieldType, Object value) { FieldValue.Builder builder = FieldValue.newBuilder(); switch (fieldType.getTypeName()) { case ARRAY: @@ -299,59 +316,146 @@ private static SchemaApi.FieldValue rowFieldToProto(FieldType fieldType, Object } } + static Object fieldValueFromProto(FieldType fieldType, SchemaApi.FieldValue value) { + switch (fieldType.getTypeName()) { + case ARRAY: + return arrayValueFromProto(fieldType.getCollectionElementType(), value.getArrayValue()); + case ITERABLE: + return iterableValueFromProto( + fieldType.getCollectionElementType(), value.getIterableValue()); + case MAP: + return mapFromProto( + fieldType.getMapKeyType(), fieldType.getMapValueType(), value.getMapValue()); + case ROW: + return rowFromProto(value.getRowValue(), fieldType); + case LOGICAL_TYPE: + default: + return primitiveFromProto(fieldType, value.getAtomicValue()); + } + } + private static SchemaApi.ArrayTypeValue arrayValueToProto( FieldType elementType, Iterable values) { return ArrayTypeValue.newBuilder() - .addAllElement(Iterables.transform(values, e -> rowFieldToProto(elementType, e))) + .addAllElement(Iterables.transform(values, e -> fieldValueToProto(elementType, e))) .build(); } + private static Iterable arrayValueFromProto( + FieldType elementType, SchemaApi.ArrayTypeValue values) { + return values.getElementList().stream() + .map(e -> fieldValueFromProto(elementType, e)) + .collect(Collectors.toList()); + } + private static SchemaApi.IterableTypeValue iterableValueToProto( FieldType elementType, Iterable values) { return IterableTypeValue.newBuilder() - .addAllElement(Iterables.transform(values, e -> rowFieldToProto(elementType, e))) + .addAllElement(Iterables.transform(values, e -> fieldValueToProto(elementType, e))) .build(); } + private static Object iterableValueFromProto(FieldType elementType, IterableTypeValue values) { + return values.getElementList().stream() + .map(e -> fieldValueFromProto(elementType, e)) + .collect(Collectors.toList()); + } + private static SchemaApi.MapTypeValue mapToProto( FieldType keyType, FieldType valueType, Map map) { MapTypeValue.Builder builder = MapTypeValue.newBuilder(); for (Map.Entry entry : map.entrySet()) { MapTypeEntry mapProtoEntry = MapTypeEntry.newBuilder() - .setKey(rowFieldToProto(keyType, entry.getKey())) - .setValue(rowFieldToProto(valueType, entry.getValue())) + .setKey(fieldValueToProto(keyType, entry.getKey())) + .setValue(fieldValueToProto(valueType, entry.getValue())) .build(); builder.addEntries(mapProtoEntry); } return builder.build(); } - private static SchemaApi.AtomicTypeValue primitiveRowFieldToProto( - FieldType fieldType, Object value) { + private static Object mapFromProto( + FieldType mapKeyType, FieldType mapValueType, MapTypeValue mapValue) { + return mapValue.getEntriesList().stream() + .collect( + Collectors.toMap( + entry -> fieldValueFromProto(mapKeyType, entry.getKey()), + entry -> fieldValueFromProto(mapValueType, entry.getValue()))); + } + + private static AtomicTypeValue primitiveRowFieldToProto(FieldType fieldType, Object value) { switch (fieldType.getTypeName()) { case BYTE: - return SchemaApi.AtomicTypeValue.newBuilder().setByte((int) value).build(); + return AtomicTypeValue.newBuilder().setByte((byte) value).build(); case INT16: - return SchemaApi.AtomicTypeValue.newBuilder().setInt16((int) value).build(); + return AtomicTypeValue.newBuilder().setInt16((short) value).build(); case INT32: - return SchemaApi.AtomicTypeValue.newBuilder().setInt32((int) value).build(); + return AtomicTypeValue.newBuilder().setInt32((int) value).build(); case INT64: - return SchemaApi.AtomicTypeValue.newBuilder().setInt64((long) value).build(); + return AtomicTypeValue.newBuilder().setInt64((long) value).build(); case FLOAT: - return SchemaApi.AtomicTypeValue.newBuilder().setFloat((float) value).build(); + return AtomicTypeValue.newBuilder().setFloat((float) value).build(); case DOUBLE: - return SchemaApi.AtomicTypeValue.newBuilder().setDouble((double) value).build(); + return AtomicTypeValue.newBuilder().setDouble((double) value).build(); case STRING: - return SchemaApi.AtomicTypeValue.newBuilder().setString((String) value).build(); + return AtomicTypeValue.newBuilder().setString((String) value).build(); case BOOLEAN: - return SchemaApi.AtomicTypeValue.newBuilder().setBoolean((boolean) value).build(); + return AtomicTypeValue.newBuilder().setBoolean((boolean) value).build(); case BYTES: - return SchemaApi.AtomicTypeValue.newBuilder() - .setBytes(ByteString.copyFrom((byte[]) value)) - .build(); + return AtomicTypeValue.newBuilder().setBytes(ByteString.copyFrom((byte[]) value)).build(); default: throw new RuntimeException("FieldType unexpected " + fieldType.getTypeName()); } } + + private static Object primitiveFromProto(FieldType fieldType, AtomicTypeValue value) { + switch (fieldType.getTypeName()) { + case BYTE: + return (byte) value.getByte(); + case INT16: + return (short) value.getInt16(); + case INT32: + return value.getInt32(); + case INT64: + return value.getInt64(); + case FLOAT: + return value.getFloat(); + case DOUBLE: + return value.getDouble(); + case STRING: + return value.getString(); + case BOOLEAN: + return value.getBoolean(); + case BYTES: + return value.getBytes().toByteArray(); + default: + throw new RuntimeException("FieldType unexpected " + fieldType.getTypeName()); + } + } + + private static List optionsToProto(Schema.Options options) { + List protoOptions = new ArrayList<>(); + for (String name : options.getOptionNames()) { + protoOptions.add( + SchemaApi.Option.newBuilder() + .setName(name) + .setType(fieldTypeToProto(Objects.requireNonNull(options.getType(name)), false)) + .setValue( + fieldValueToProto( + Objects.requireNonNull(options.getType(name)), options.getValue(name))) + .build()); + } + return protoOptions; + } + + private static Schema.Options optionsFromProto(List protoOptions) { + Schema.Options.Builder optionBuilder = Schema.Options.builder(); + for (SchemaApi.Option protoOption : protoOptions) { + FieldType fieldType = fieldTypeFromProto(protoOption.getType()); + optionBuilder.setOption( + protoOption.getName(), fieldType, fieldValueFromProto(fieldType, protoOption.getValue())); + } + return optionBuilder.build(); + } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java index 9c35657c7ce4..8f4cf6243874 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java @@ -17,12 +17,12 @@ */ package org.apache.beam.sdk.values; +import static org.apache.beam.sdk.values.SchemaVerification.verifyRowValues; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; import java.io.Serializable; import java.math.BigDecimal; -import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -30,7 +30,6 @@ import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.Map.Entry; import java.util.Objects; import java.util.stream.Collector; import javax.annotation.Nullable; @@ -39,17 +38,13 @@ import org.apache.beam.sdk.schemas.Factory; import org.apache.beam.sdk.schemas.FieldValueGetter; import org.apache.beam.sdk.schemas.Schema; -import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.Schema.LogicalType; import org.apache.beam.sdk.schemas.Schema.TypeName; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; import org.joda.time.DateTime; -import org.joda.time.Instant; import org.joda.time.ReadableDateTime; import org.joda.time.ReadableInstant; -import org.joda.time.base.AbstractInstant; /** * {@link Row} is an immutable tuple-like schema to represent one element in a {@link PCollection}. @@ -434,7 +429,7 @@ public static boolean deepEquals(Object a, Object b, Schema.FieldType fieldType) } } - static int deepHashCode(Object a, Schema.FieldType fieldType) { + public static int deepHashCode(Object a, Schema.FieldType fieldType) { if (a == null) { return 0; } else if (fieldType.getTypeName() == TypeName.LOGICAL_TYPE) { @@ -614,230 +609,13 @@ public Builder withFieldValueGetters( return this; } - private List verify(Schema schema, List values) { - List verifiedValues = Lists.newArrayListWithCapacity(values.size()); - if (schema.getFieldCount() != values.size()) { - throw new IllegalArgumentException( - String.format( - "Field count in Schema (%s) (%d) and values (%s) (%d) must match", - schema.getFieldNames(), schema.getFieldCount(), values, values.size())); - } - for (int i = 0; i < values.size(); ++i) { - Object value = values.get(i); - Schema.Field field = schema.getField(i); - if (value == null) { - if (!field.getType().getNullable()) { - throw new IllegalArgumentException( - String.format("Field %s is not nullable", field.getName())); - } - verifiedValues.add(null); - } else { - verifiedValues.add(verify(value, field.getType(), field.getName())); - } - } - return verifiedValues; - } - - private Object verify(Object value, FieldType type, String fieldName) { - if (TypeName.ARRAY.equals(type.getTypeName())) { - return verifyArray(value, type.getCollectionElementType(), fieldName); - } else if (TypeName.ITERABLE.equals(type.getTypeName())) { - return verifyIterable(value, type.getCollectionElementType(), fieldName); - } - if (TypeName.MAP.equals(type.getTypeName())) { - return verifyMap(value, type.getMapKeyType(), type.getMapValueType(), fieldName); - } else if (TypeName.ROW.equals(type.getTypeName())) { - return verifyRow(value, fieldName); - } else if (TypeName.LOGICAL_TYPE.equals(type.getTypeName())) { - return verifyLogicalType(value, type.getLogicalType(), fieldName); - } else { - return verifyPrimitiveType(value, type.getTypeName(), fieldName); - } - } - - private Object verifyLogicalType(Object value, LogicalType logicalType, String fieldName) { - return verify(logicalType.toBaseType(value), logicalType.getBaseType(), fieldName); - } - - private List verifyArray( - Object value, FieldType collectionElementType, String fieldName) { - boolean collectionElementTypeNullable = collectionElementType.getNullable(); - if (!(value instanceof Collection)) { - throw new IllegalArgumentException( - String.format( - "For field name %s and array type expected Collection class. Instead " - + "class type was %s.", - fieldName, value.getClass())); - } - Collection valueCollection = (Collection) value; - List verifiedList = Lists.newArrayListWithCapacity(valueCollection.size()); - for (Object listValue : valueCollection) { - if (listValue == null) { - if (!collectionElementTypeNullable) { - throw new IllegalArgumentException( - String.format( - "%s is not nullable in Array field %s", collectionElementType, fieldName)); - } - verifiedList.add(null); - } else { - verifiedList.add(verify(listValue, collectionElementType, fieldName)); - } - } - return verifiedList; - } - - private Iterable verifyIterable( - Object value, FieldType collectionElementType, String fieldName) { - boolean collectionElementTypeNullable = collectionElementType.getNullable(); - if (!(value instanceof Iterable)) { - throw new IllegalArgumentException( - String.format( - "For field name %s and iterable type expected class extending Iterable. Instead " - + "class type was %s.", - fieldName, value.getClass())); - } - Iterable valueIterable = (Iterable) value; - for (Object listValue : valueIterable) { - if (listValue == null) { - if (!collectionElementTypeNullable) { - throw new IllegalArgumentException( - String.format( - "%s is not nullable in Array field %s", collectionElementType, fieldName)); - } - } else { - verify(listValue, collectionElementType, fieldName); - } - } - return valueIterable; - } - - private Map verifyMap( - Object value, FieldType keyType, FieldType valueType, String fieldName) { - boolean valueTypeNullable = valueType.getNullable(); - if (!(value instanceof Map)) { - throw new IllegalArgumentException( - String.format( - "For field name %s and map type expected Map class. Instead " - + "class type was %s.", - fieldName, value.getClass())); - } - Map valueMap = (Map) value; - Map verifiedMap = Maps.newHashMapWithExpectedSize(valueMap.size()); - for (Entry kv : valueMap.entrySet()) { - if (kv.getValue() == null) { - if (!valueTypeNullable) { - throw new IllegalArgumentException( - String.format("%s is not nullable in Map field %s", valueType, fieldName)); - } - verifiedMap.put(verify(kv.getKey(), keyType, fieldName), null); - } else { - verifiedMap.put( - verify(kv.getKey(), keyType, fieldName), verify(kv.getValue(), valueType, fieldName)); - } - } - return verifiedMap; - } - - private Row verifyRow(Object value, String fieldName) { - if (!(value instanceof Row)) { - throw new IllegalArgumentException( - String.format( - "For field name %s expected Row type. " + "Instead class type was %s.", - fieldName, value.getClass())); - } - // No need to recursively validate the nested Row, since there's no way to build the - // Row object without it validating. - return (Row) value; - } - - private Object verifyPrimitiveType(Object value, TypeName type, String fieldName) { - if (type.isDateType()) { - return verifyDateTime(value, fieldName); - } else { - switch (type) { - case BYTE: - if (value instanceof Byte) { - return value; - } - break; - case BYTES: - if (value instanceof ByteBuffer) { - return ((ByteBuffer) value).array(); - } else if (value instanceof byte[]) { - return (byte[]) value; - } - break; - case INT16: - if (value instanceof Short) { - return value; - } - break; - case INT32: - if (value instanceof Integer) { - return value; - } - break; - case INT64: - if (value instanceof Long) { - return value; - } - break; - case DECIMAL: - if (value instanceof BigDecimal) { - return value; - } - break; - case FLOAT: - if (value instanceof Float) { - return value; - } - break; - case DOUBLE: - if (value instanceof Double) { - return value; - } - break; - case STRING: - if (value instanceof String) { - return value; - } - break; - case BOOLEAN: - if (value instanceof Boolean) { - return value; - } - break; - default: - // Shouldn't actually get here, but we need this case to satisfy linters. - throw new IllegalArgumentException( - String.format("Not a primitive type for field name %s: %s", fieldName, type)); - } - throw new IllegalArgumentException( - String.format( - "For field name %s and type %s found incorrect class type %s", - fieldName, type, value.getClass())); - } - } - - private Instant verifyDateTime(Object value, String fieldName) { - // We support the following classes for datetimes. - if (value instanceof AbstractInstant) { - return ((AbstractInstant) value).toInstant(); - } else { - throw new IllegalArgumentException( - String.format( - "For field name %s and DATETIME type got unexpected class %s ", - fieldName, value.getClass())); - } - } - public Row build() { checkNotNull(schema); if (!this.values.isEmpty() && fieldValueGetterFactory != null) { throw new IllegalArgumentException("Cannot specify both values and getters."); } if (!this.values.isEmpty()) { - List storageValues = attached ? this.values : verify(schema, this.values); + List storageValues = attached ? this.values : verifyRowValues(schema, this.values); checkState(getterTarget == null, "withGetterTarget requires getters."); return new RowWithStorage(schema, storageValues); } else if (fieldValueGetterFactory != null) { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/SchemaVerification.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/SchemaVerification.java new file mode 100644 index 000000000000..e06e7028bcb2 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/SchemaVerification.java @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.values; + +import java.io.Serializable; +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.Schema.LogicalType; +import org.apache.beam.sdk.schemas.Schema.TypeName; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; +import org.joda.time.Instant; +import org.joda.time.base.AbstractInstant; + +@Experimental +public abstract class SchemaVerification implements Serializable { + + static List verifyRowValues(Schema schema, List values) { + List verifiedValues = Lists.newArrayListWithCapacity(values.size()); + if (schema.getFieldCount() != values.size()) { + throw new IllegalArgumentException( + String.format( + "Field count in Schema (%s) (%d) and values (%s) (%d) must match", + schema.getFieldNames(), schema.getFieldCount(), values, values.size())); + } + for (int i = 0; i < values.size(); ++i) { + Object value = values.get(i); + Schema.Field field = schema.getField(i); + if (value == null) { + if (!field.getType().getNullable()) { + throw new IllegalArgumentException( + String.format("Field %s is not nullable", field.getName())); + } + verifiedValues.add(null); + } else { + verifiedValues.add(verifyFieldValue(value, field.getType(), field.getName())); + } + } + return verifiedValues; + } + + public static Object verifyFieldValue(Object value, FieldType type, String fieldName) { + if (TypeName.ARRAY.equals(type.getTypeName())) { + return verifyArray(value, type.getCollectionElementType(), fieldName); + } else if (TypeName.ITERABLE.equals(type.getTypeName())) { + return verifyIterable(value, type.getCollectionElementType(), fieldName); + } + if (TypeName.MAP.equals(type.getTypeName())) { + return verifyMap(value, type.getMapKeyType(), type.getMapValueType(), fieldName); + } else if (TypeName.ROW.equals(type.getTypeName())) { + return verifyRow(value, fieldName); + } else if (TypeName.LOGICAL_TYPE.equals(type.getTypeName())) { + return verifyLogicalType(value, type.getLogicalType(), fieldName); + } else { + return verifyPrimitiveType(value, type.getTypeName(), fieldName); + } + } + + private static Object verifyLogicalType(Object value, LogicalType logicalType, String fieldName) { + return verifyFieldValue(logicalType.toBaseType(value), logicalType.getBaseType(), fieldName); + } + + private static List verifyArray( + Object value, FieldType collectionElementType, String fieldName) { + boolean collectionElementTypeNullable = collectionElementType.getNullable(); + if (!(value instanceof List)) { + throw new IllegalArgumentException( + String.format( + "For field name %s and array type expected List class. Instead " + + "class type was %s.", + fieldName, value.getClass())); + } + List valueList = (List) value; + List verifiedList = Lists.newArrayListWithCapacity(valueList.size()); + for (Object listValue : valueList) { + if (listValue == null) { + if (!collectionElementTypeNullable) { + throw new IllegalArgumentException( + String.format( + "%s is not nullable in Array field %s", collectionElementType, fieldName)); + } + verifiedList.add(null); + } else { + verifiedList.add(verifyFieldValue(listValue, collectionElementType, fieldName)); + } + } + return verifiedList; + } + + private static Iterable verifyIterable( + Object value, FieldType collectionElementType, String fieldName) { + boolean collectionElementTypeNullable = collectionElementType.getNullable(); + if (!(value instanceof Iterable)) { + throw new IllegalArgumentException( + String.format( + "For field name %s and iterable type expected class extending Iterable. Instead " + + "class type was %s.", + fieldName, value.getClass())); + } + Iterable valueList = (Iterable) value; + for (Object listValue : valueList) { + if (listValue == null) { + if (!collectionElementTypeNullable) { + throw new IllegalArgumentException( + String.format( + "%s is not nullable in Array field %s", collectionElementType, fieldName)); + } + } else { + verifyFieldValue(listValue, collectionElementType, fieldName); + } + } + return valueList; + } + + private static Map verifyMap( + Object value, FieldType keyType, FieldType valueType, String fieldName) { + boolean valueTypeNullable = valueType.getNullable(); + if (!(value instanceof Map)) { + throw new IllegalArgumentException( + String.format( + "For field name %s and map type expected Map class. Instead " + "class type was %s.", + fieldName, value.getClass())); + } + Map valueMap = (Map) value; + Map verifiedMap = Maps.newHashMapWithExpectedSize(valueMap.size()); + for (Entry kv : valueMap.entrySet()) { + if (kv.getValue() == null) { + if (!valueTypeNullable) { + throw new IllegalArgumentException( + String.format("%s is not nullable in Map field %s", valueType, fieldName)); + } + verifiedMap.put(verifyFieldValue(kv.getKey(), keyType, fieldName), null); + } else { + verifiedMap.put( + verifyFieldValue(kv.getKey(), keyType, fieldName), + verifyFieldValue(kv.getValue(), valueType, fieldName)); + } + } + return verifiedMap; + } + + private static Row verifyRow(Object value, String fieldName) { + if (!(value instanceof Row)) { + throw new IllegalArgumentException( + String.format( + "For field name %s expected Row type. " + "Instead class type was %s.", + fieldName, value.getClass())); + } + // No need to recursively validate the nested Row, since there's no way to build the + // Row object without it validating. + return (Row) value; + } + + private static Object verifyPrimitiveType(Object value, TypeName type, String fieldName) { + if (type.isDateType()) { + return verifyDateTime(value, fieldName); + } else { + switch (type) { + case BYTE: + if (value instanceof Byte) { + return value; + } + break; + case BYTES: + if (value instanceof ByteBuffer) { + return ((ByteBuffer) value).array(); + } else if (value instanceof byte[]) { + return (byte[]) value; + } + break; + case INT16: + if (value instanceof Short) { + return value; + } + break; + case INT32: + if (value instanceof Integer) { + return value; + } + break; + case INT64: + if (value instanceof Long) { + return value; + } + break; + case DECIMAL: + if (value instanceof BigDecimal) { + return value; + } + break; + case FLOAT: + if (value instanceof Float) { + return value; + } + break; + case DOUBLE: + if (value instanceof Double) { + return value; + } + break; + case STRING: + if (value instanceof String) { + return value; + } + break; + case BOOLEAN: + if (value instanceof Boolean) { + return value; + } + break; + default: + // Shouldn't actually get here, but we need this case to satisfy linters. + throw new IllegalArgumentException( + String.format("Not a primitive type for field name %s: %s", fieldName, type)); + } + throw new IllegalArgumentException( + String.format( + "For field name %s and type %s found incorrect class type %s", + fieldName, type, value.getClass())); + } + } + + private static Instant verifyDateTime(Object value, String fieldName) { + // We support the following classes for datetimes. + if (value instanceof AbstractInstant) { + return ((AbstractInstant) value).toInstant(); + } else { + throw new IllegalArgumentException( + String.format( + "For field name %s and DATETIME type got unexpected class %s ", + fieldName, value.getClass())); + } + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/SchemaOptionsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/SchemaOptionsTest.java new file mode 100644 index 000000000000..f3e3685f98c8 --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/SchemaOptionsTest.java @@ -0,0 +1,347 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.schemas; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.apache.beam.sdk.schemas.Schema.Field; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.values.Row; +import org.joda.time.DateTime; +import org.joda.time.DateTimeZone; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +/** Unit tests for {@link Schema.Options}. */ +public class SchemaOptionsTest { + + private static final String OPTION_NAME = "beam:test:field_i1"; + private static final String FIELD_NAME = "f_field"; + private static final Field FIELD = Field.of(FIELD_NAME, FieldType.STRING); + private static final Row TEST_ROW = + Row.withSchema( + Schema.builder() + .addField("field_one", FieldType.STRING) + .addField("field_two", FieldType.INT32) + .build()) + .addValue("value") + .addValue(42) + .build(); + + private static final Map TEST_MAP = new HashMap<>(); + + static { + TEST_MAP.put(1, "one"); + TEST_MAP.put(2, "two"); + } + + private static final List TEST_LIST = new ArrayList<>(); + + static { + TEST_LIST.add("one"); + TEST_LIST.add("two"); + TEST_LIST.add("three"); + } + + private static final byte[] TEST_BYTES = new byte[] {0x42, 0x69, 0x00}; + + @Rule public ExpectedException thrown = ExpectedException.none(); + + @Test + public void testBooleanOption() { + Schema.Options options = Schema.Options.setOption(OPTION_NAME, FieldType.BOOLEAN, true).build(); + assertEquals(true, options.getValue(OPTION_NAME)); + assertEquals(FieldType.BOOLEAN, options.getType(OPTION_NAME)); + } + + @Test + public void testInt16Option() { + Schema.Options options = + Schema.Options.setOption(OPTION_NAME, FieldType.INT16, (short) 12).build(); + assertEquals(Short.valueOf((short) 12), options.getValue(OPTION_NAME)); + assertEquals(FieldType.INT16, options.getType(OPTION_NAME)); + } + + @Test + public void testByteOption() { + Schema.Options options = + Schema.Options.setOption(OPTION_NAME, FieldType.BYTE, (byte) 12).build(); + assertEquals(Byte.valueOf((byte) 12), options.getValue(OPTION_NAME)); + assertEquals(FieldType.BYTE, options.getType(OPTION_NAME)); + } + + @Test + public void testInt32Option() { + Schema.Options options = Schema.Options.setOption(OPTION_NAME, FieldType.INT32, 12).build(); + assertEquals(Integer.valueOf(12), options.getValue(OPTION_NAME)); + assertEquals(FieldType.INT32, options.getType(OPTION_NAME)); + } + + @Test + public void testInt64Option() { + Schema.Options options = Schema.Options.setOption(OPTION_NAME, FieldType.INT64, 12L).build(); + assertEquals(Long.valueOf(12), options.getValue(OPTION_NAME)); + assertEquals(FieldType.INT64, options.getType(OPTION_NAME)); + } + + @Test + public void testFloatOption() { + Schema.Options options = + Schema.Options.setOption(OPTION_NAME, FieldType.FLOAT, (float) 12.0).build(); + assertEquals(Float.valueOf((float) 12.0), options.getValue(OPTION_NAME)); + assertEquals(FieldType.FLOAT, options.getType(OPTION_NAME)); + } + + @Test + public void testDoubleOption() { + Schema.Options options = Schema.Options.setOption(OPTION_NAME, FieldType.DOUBLE, 12.0).build(); + assertEquals(Double.valueOf(12.0), options.getValue(OPTION_NAME)); + assertEquals(FieldType.DOUBLE, options.getType(OPTION_NAME)); + } + + @Test + public void testStringOption() { + Schema.Options options = Schema.Options.setOption(OPTION_NAME, FieldType.STRING, "foo").build(); + assertEquals("foo", options.getValue(OPTION_NAME)); + assertEquals(FieldType.STRING, options.getType(OPTION_NAME)); + } + + @Test + public void testBytesOption() { + byte[] bytes = new byte[] {0x42, 0x69, 0x00}; + Schema.Options options = Schema.Options.setOption(OPTION_NAME, FieldType.BYTES, bytes).build(); + assertEquals(bytes, options.getValue(OPTION_NAME)); + assertEquals(FieldType.BYTES, options.getType(OPTION_NAME)); + } + + @Test + public void testDateTimeOption() { + DateTime now = DateTime.now().withZone(DateTimeZone.UTC); + Schema.Options options = Schema.Options.setOption(OPTION_NAME, FieldType.DATETIME, now).build(); + assertEquals(now, options.getValue(OPTION_NAME)); + assertEquals(FieldType.DATETIME, options.getType(OPTION_NAME)); + } + + @Test + public void testArrayOfIntegerOption() { + Schema.Options options = + Schema.Options.setOption(OPTION_NAME, FieldType.array(FieldType.STRING), TEST_LIST).build(); + assertEquals(TEST_LIST, options.getValue(OPTION_NAME)); + assertEquals(FieldType.array(FieldType.STRING), options.getType(OPTION_NAME)); + } + + @Test + public void testMapOfIntegerStringOption() { + Schema.Options options = + Schema.Options.setOption( + OPTION_NAME, FieldType.map(FieldType.INT32, FieldType.STRING), TEST_MAP) + .build(); + assertEquals(TEST_MAP, options.getValue(OPTION_NAME)); + assertEquals(FieldType.map(FieldType.INT32, FieldType.STRING), options.getType(OPTION_NAME)); + } + + @Test + public void testRowOption() { + Schema.Options options = Schema.Options.setRowOption(OPTION_NAME, TEST_ROW).build(); + assertEquals(TEST_ROW, options.getValue(OPTION_NAME)); + assertEquals(FieldType.row(TEST_ROW.getSchema()), options.getType(OPTION_NAME)); + } + + @Test(expected = IllegalArgumentException.class) + public void testNotNullableOptionSetNull() { + Schema.Options options = Schema.Options.setOption(OPTION_NAME, FieldType.STRING, null).build(); + } + + @Test + public void testNullableOptionSetNull() { + Schema.Options options = + Schema.Options.setOption(OPTION_NAME, FieldType.STRING.withNullable(true), null).build(); + assertNull(options.getValue(OPTION_NAME)); + assertEquals(FieldType.STRING.withNullable(true), options.getType(OPTION_NAME)); + } + + @Test(expected = IllegalArgumentException.class) + public void testGetValueNoOption() { + Schema.Options options = Schema.Options.none(); + options.getValue("foo"); + } + + @Test(expected = IllegalArgumentException.class) + public void testGetTypeNoOption() { + Schema.Options options = Schema.Options.none(); + options.getType("foo"); + } + + @Test + public void testGetValueOrDefault() { + Schema.Options options = Schema.Options.none(); + assertNull(options.getValueOrDefault("foo", null)); + } + + private Schema.Options.Builder setOptionsSet1() { + return setOptionsSet1(Schema.Options.builder()); + } + + private Schema.Options.Builder setOptionsSet1(Schema.Options.Builder builder) { + return builder + .setOption("field_option_boolean", FieldType.BOOLEAN, true) + .setOption("field_option_byte", FieldType.BYTE, (byte) 12) + .setOption("field_option_int16", FieldType.INT16, (short) 12) + .setOption("field_option_int32", FieldType.INT32, 12) + .setOption("field_option_int64", FieldType.INT64, 12L) + .setOption("field_option_string", FieldType.STRING, "foo"); + } + + private Schema.Options.Builder setOptionsSet2() { + return setOptionsSet2(Schema.Options.builder()); + } + + private void assertOptionSet1(Schema.Options options) { + assertEquals(true, options.getValue("field_option_boolean")); + assertEquals(12, (byte) options.getValue("field_option_byte")); + assertEquals(12, (short) options.getValue("field_option_int16")); + assertEquals(12, (int) options.getValue("field_option_int32")); + assertEquals(12L, (long) options.getValue("field_option_int64")); + assertEquals("foo", options.getValue("field_option_string")); + } + + private Schema.Options.Builder setOptionsSet2(Schema.Options.Builder builder) { + return builder + .setOption("field_option_bytes", FieldType.BYTES, new byte[] {0x42, 0x69, 0x00}) + .setOption("field_option_float", FieldType.FLOAT, (float) 12.0) + .setOption("field_option_double", FieldType.DOUBLE, 12.0) + .setOption("field_option_map", FieldType.map(FieldType.INT32, FieldType.STRING), TEST_MAP) + .setOption("field_option_array", FieldType.array(FieldType.STRING), TEST_LIST) + .setRowOption("field_option_row", TEST_ROW) + .setOption("field_option_value", FieldType.STRING, "other"); + } + + private void assertOptionSet2(Schema.Options options) { + assertArrayEquals(TEST_BYTES, options.getValue("field_option_bytes")); + assertEquals((float) 12.0, (float) options.getValue("field_option_float"), 0.1); + assertEquals(12.0, (double) options.getValue("field_option_double"), 0.1); + assertEquals(TEST_MAP, options.getValue("field_option_map")); + assertEquals(TEST_LIST, options.getValue("field_option_array")); + assertEquals(TEST_ROW, options.getValue("field_option_row")); + assertEquals("other", options.getValue("field_option_value")); + } + + @Test + public void testSchemaSetOptionWithBuilder() { + Schema schema = + Schema.builder() + .setOptions(setOptionsSet1(Schema.Options.builder())) + .addField(FIELD) + .build(); + assertOptionSet1(schema.getOptions()); + } + + @Test + public void testSchemaSetOption() { + Schema schema = + Schema.builder() + .setOptions(setOptionsSet1(Schema.Options.builder()).build()) + .addField(FIELD) + .build(); + assertOptionSet1(schema.getOptions()); + } + + @Test + public void testFieldWithOptionsBuilder() { + Schema schema = Schema.builder().addField(FIELD.withOptions(setOptionsSet1())).build(); + assertOptionSet1(schema.getField(FIELD_NAME).getOptions()); + } + + @Test + public void testFieldWithOptions() { + Schema schema = Schema.builder().addField(FIELD.withOptions(setOptionsSet1().build())).build(); + assertOptionSet1(schema.getField(FIELD_NAME).getOptions()); + } + + @Test + public void testFieldHasOptions() { + Schema schema = Schema.builder().addField(FIELD.withOptions(setOptionsSet1().build())).build(); + assertTrue(schema.getField(FIELD_NAME).getOptions().hasOptions()); + } + + @Test + public void testFieldHasNonOptions() { + Schema schema = Schema.builder().addField(FIELD).build(); + assertFalse(schema.getField(FIELD_NAME).getOptions().hasOptions()); + } + + @Test + public void testFieldHasOption() { + Schema schema = Schema.builder().addField(FIELD.withOptions(setOptionsSet1().build())).build(); + assertTrue(schema.getField(FIELD_NAME).getOptions().hasOption("field_option_byte")); + assertFalse(schema.getField(FIELD_NAME).getOptions().hasOption("foo_bar")); + } + + @Test + public void testFieldOptionNames() { + Schema schema = Schema.builder().addField(FIELD.withOptions(setOptionsSet1().build())).build(); + Set optionNames = schema.getField(FIELD_NAME).getOptions().getOptionNames(); + assertThat( + optionNames, + containsInAnyOrder( + "field_option_boolean", + "field_option_byte", + "field_option_int16", + "field_option_int32", + "field_option_int64", + "field_option_string")); + } + + @Test + public void testFieldBuilderSetOptions() { + Schema schema = + Schema.builder() + .addField(FIELD.toBuilder().setOptions(setOptionsSet1().build()).build()) + .build(); + assertOptionSet1(schema.getField(FIELD_NAME).getOptions()); + } + + @Test + public void testFieldBuilderSetOptionsBuilder() { + Schema schema = + Schema.builder().addField(FIELD.toBuilder().setOptions(setOptionsSet1()).build()).build(); + assertOptionSet1(schema.getField(FIELD_NAME).getOptions()); + } + + @Test + public void testAddOptions() { + Schema.Options options = + setOptionsSet1(Schema.Options.builder()) + .addOptions(setOptionsSet2(Schema.Options.builder()).build()) + .build(); + assertOptionSet1(options); + assertOptionSet2(options); + } +}