From d851e6cae36d9acee845ca24a4cf50ff5ce704b4 Mon Sep 17 00:00:00 2001 From: Alex Van Boxel Date: Sun, 1 Dec 2019 12:21:01 +0100 Subject: [PATCH] [BEAM-9035] BIP-1: Typed options for Row Schema and Field This is the first PR of a multipart commit: this ticket implements the basic infrastructure of options on row and field. Options in Beam Schema add extra context to fields and schema. In contracts to metadata, options would be added to fields, logical types and rows. Options are key/typed value combination. The type system is using the beam schema itself and the value can be any type that is supported by the beam schema, including row. --- model/pipeline/src/main/proto/schema.proto | 8 + .../core/construction/CoderTranslators.java | 2 +- .../core/construction/CommonCoderTest.java | 3 +- .../construction/SchemaTranslationTest.java | 81 +++- .../SchemaCoderCloudObjectTranslator.java | 2 +- .../org/apache/beam/sdk/schemas/Schema.java | 306 +++++++++++++-- .../beam/sdk/schemas/SchemaTranslation.java | 150 ++++++-- .../java/org/apache/beam/sdk/values/Row.java | 228 +----------- .../beam/sdk/values/SchemaVerification.java | 255 +++++++++++++ .../beam/sdk/schemas/SchemaOptionsTest.java | 347 ++++++++++++++++++ 10 files changed, 1107 insertions(+), 275 deletions(-) create mode 100644 sdks/java/core/src/main/java/org/apache/beam/sdk/values/SchemaVerification.java create mode 100644 sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/SchemaOptionsTest.java diff --git a/model/pipeline/src/main/proto/schema.proto b/model/pipeline/src/main/proto/schema.proto index 2cf404e3341f..dcf75ca4a5ee 100644 --- a/model/pipeline/src/main/proto/schema.proto +++ b/model/pipeline/src/main/proto/schema.proto @@ -31,6 +31,7 @@ option java_outer_classname = "SchemaApi"; message Schema { repeated Field fields = 1; string id = 2; + repeated Option options = 3; } message Field { @@ -39,6 +40,7 @@ message Field { FieldType type = 3; int32 id = 4; int32 encoding_position = 5; + repeated Option options = 6; } message FieldType { @@ -91,6 +93,12 @@ message LogicalType { FieldValue argument = 5; } +message Option { + string name = 1; + FieldType type = 2; + FieldValue value = 3; +} + message Row { repeated FieldValue values = 1; } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java index 81b7922fe495..d7abf4e8c93e 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java @@ -163,7 +163,7 @@ public RowCoder fromComponents(List> components, byte[] payload) { components.isEmpty(), "Expected empty component list, but received: " + components); Schema schema; try { - schema = SchemaTranslation.fromProto(SchemaApi.Schema.parseFrom(payload)); + schema = SchemaTranslation.schemaFromProto(SchemaApi.Schema.parseFrom(payload)); } catch (InvalidProtocolBufferException e) { throw new RuntimeException("Unable to parse schema for RowCoder: ", e); } diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CommonCoderTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CommonCoderTest.java index 4ec1a7d58a0c..2ac9db5ee364 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CommonCoderTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CommonCoderTest.java @@ -302,7 +302,8 @@ private static Object convertValue(Object value, CommonCoder coderSpec, Coder co } else if (s.equals(getUrn(StandardCoders.Enum.ROW))) { Schema schema; try { - schema = SchemaTranslation.fromProto(SchemaApi.Schema.parseFrom(coderSpec.getPayload())); + schema = + SchemaTranslation.schemaFromProto(SchemaApi.Schema.parseFrom(coderSpec.getPayload())); } catch (InvalidProtocolBufferException e) { throw new RuntimeException("Failed to parse schema payload for row coder", e); } diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SchemaTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SchemaTranslationTest.java index 6022d70999cc..bfb9adb4988b 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SchemaTranslationTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SchemaTranslationTest.java @@ -20,12 +20,17 @@ import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertThat; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import org.apache.beam.model.pipeline.v1.SchemaApi; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.SchemaTranslation; import org.apache.beam.sdk.schemas.logicaltypes.FixedBytes; +import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; import org.junit.Test; import org.junit.runner.RunWith; @@ -41,6 +46,37 @@ public class SchemaTranslationTest { public static class ToFromProtoTest { @Parameters(name = "{index}: {0}") public static Iterable data() { + Map optionMap = new HashMap<>(); + optionMap.put("string", 42); + List optionList = new ArrayList<>(); + optionList.add("string"); + Row optionRow = + Row.withSchema( + Schema.builder() + .addField("field_one", FieldType.STRING) + .addField("field_two", FieldType.INT32) + .build()) + .addValue("value") + .addValue(42) + .build(); + + Schema.Options.Builder optionsBuilder = + Schema.Options.builder() + .setOption("field_option_boolean", FieldType.BOOLEAN, true) + .setOption("field_option_byte", FieldType.BYTE, (byte) 12) + .setOption("field_option_int16", FieldType.INT16, (short) 12) + .setOption("field_option_int32", FieldType.INT32, 12) + .setOption("field_option_int64", FieldType.INT64, 12L) + .setOption("field_option_string", FieldType.STRING, "foo") + .setOption("field_option_bytes", FieldType.BYTES, new byte[] {0x42, 0x69, 0x00}) + .setOption("field_option_float", FieldType.FLOAT, (float) 12.0) + .setOption("field_option_double", FieldType.DOUBLE, 12.0) + .setOption( + "field_option_map", FieldType.map(FieldType.STRING, FieldType.INT32), optionMap) + .setOption("field_option_array", FieldType.array(FieldType.STRING), optionList) + .setRowOption("field_option_row", optionRow) + .setOption("field_option_value", FieldType.STRING, "other"); + return ImmutableList.builder() .add(Schema.of(Field.of("string", FieldType.STRING))) .add( @@ -77,6 +113,49 @@ public static Iterable data() { Schema.of( Field.of("decimal", FieldType.DECIMAL), Field.of("datetime", FieldType.DATETIME))) .add(Schema.of(Field.of("logical", FieldType.logicalType(FixedBytes.of(24))))) + .add( + Schema.of( + Field.of("field_with_option_atomic", FieldType.STRING) + .withOptions( + Schema.Options.builder() + .setOption( + "field_option_atomic", FieldType.INT32, Integer.valueOf(42)) + .build())) + .withOptions( + Schema.Options.builder() + .setOption("schema_option_atomic", FieldType.BOOLEAN, true))) + .add( + Schema.of( + Field.of("field_with_option_map", FieldType.STRING) + .withOptions( + Schema.Options.builder() + .setOption( + "field_option_map", + FieldType.map(FieldType.STRING, FieldType.INT32), + optionMap))) + .withOptions( + Schema.Options.builder() + .setOption( + "field_option_map", + FieldType.map(FieldType.STRING, FieldType.INT32), + optionMap))) + .add( + Schema.of( + Field.of("field_with_option_array", FieldType.STRING) + .withOptions( + Schema.Options.builder() + .setOption( + "field_option_array", + FieldType.array(FieldType.STRING), + optionList) + .build())) + .withOptions( + Schema.Options.builder() + .setOption( + "field_option_array", FieldType.array(FieldType.STRING), optionList))) + .add( + Schema.of(Field.of("field", FieldType.STRING).withOptions(optionsBuilder)) + .withOptions(optionsBuilder)) .build(); } @@ -87,7 +166,7 @@ public static Iterable data() { public void toAndFromProto() throws Exception { SchemaApi.Schema schemaProto = SchemaTranslation.schemaToProto(schema, true); - Schema decodedSchema = SchemaTranslation.fromProto(schemaProto); + Schema decodedSchema = SchemaTranslation.schemaFromProto(schemaProto); assertThat(decodedSchema, equalTo(schema)); } } diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/SchemaCoderCloudObjectTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/SchemaCoderCloudObjectTranslator.java index 8987abb37ca7..4ea07178228d 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/SchemaCoderCloudObjectTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/SchemaCoderCloudObjectTranslator.java @@ -91,7 +91,7 @@ public SchemaCoder fromCloudObject(CloudObject cloudObject) { SchemaApi.Schema protoSchema = SchemaApi.Schema.parseFrom( StringUtils.jsonStringToByteArray(Structs.getString(cloudObject, SCHEMA))); - Schema schema = SchemaTranslation.fromProto(protoSchema); + Schema schema = SchemaTranslation.schemaFromProto(protoSchema); return SchemaCoder.of(schema, typeDescriptor, toRowFunction, fromRowFunction); } catch (IOException e) { throw new RuntimeException(e); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java index 67eb832f13ba..5689ffc813f7 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java @@ -17,6 +17,8 @@ */ package org.apache.beam.sdk.schemas; +import static org.apache.beam.sdk.values.SchemaVerification.verifyFieldValue; + import com.google.auto.value.AutoValue; import java.io.Serializable; import java.nio.charset.StandardCharsets; @@ -24,10 +26,12 @@ import java.util.Arrays; import java.util.Collections; import java.util.Comparator; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.TreeMap; import java.util.UUID; import java.util.stream.Collector; import java.util.stream.Collectors; @@ -36,6 +40,7 @@ import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.BiMap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.HashBiMap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; @@ -93,9 +98,12 @@ public String toString() { // equal, so we can short circuit comparison. @Nullable private UUID uuid = null; + private final Options options; + /** Builder class for building {@link Schema} objects. */ public static class Builder { List fields; + Options options = Options.none(); public Builder() { this.fields = Lists.newArrayList(); @@ -206,12 +214,23 @@ public Builder addMapField(String name, FieldType keyType, FieldType valueType) return this; } + /** Returns a copy of the Field with isNullable set. */ + public Builder setOptions(Options options) { + this.options = options; + return this; + } + + public Builder setOptions(Options.Builder optionsBuilder) { + this.options = optionsBuilder.build(); + return this; + } + public int getLastFieldId() { return fields.size() - 1; } public Schema build() { - return new Schema(fields); + return new Schema(fields, options); } } @@ -220,23 +239,37 @@ public static Builder builder() { } public Schema(List fields) { + this(fields, Options.none()); + } + + public Schema(List fields, Options options) { this.fields = fields; int index = 0; for (Field field : fields) { - if (fieldIndices.get(field.getName()) != null) { - throw new IllegalArgumentException( - "Duplicate field " + field.getName() + " added to schema"); - } + Preconditions.checkArgument( + fieldIndices.get(field.getName()) == null, + "Duplicate field " + field.getName() + " added to schema"); encodingPositions.put(field.getName(), index); fieldIndices.put(field.getName(), index++); } this.hashCode = Objects.hash(fieldIndices, fields); + this.options = options; } public static Schema of(Field... fields) { return Schema.builder().addFields(fields).build(); } + /** Returns a copy of the Schema with the options set. */ + public Schema withOptions(Options options) { + return new Schema(fields, getOptions().toBuilder().addOptions(options).build()); + } + + /** Returns a copy of the Schema with the options set. */ + public Schema withOptions(Options.Builder optionsBuilder) { + return withOptions(optionsBuilder.build()); + } + /** Set this schema's UUID. All schemas with the same UUID must be guaranteed to be identical. */ public void setUUID(UUID uuid) { this.uuid = uuid; @@ -273,7 +306,8 @@ public boolean equals(Object o) { return Objects.equals(uuid, other.uuid); } return Objects.equals(fieldIndices, other.fieldIndices) - && Objects.equals(getFields(), other.getFields()); + && Objects.equals(getFields(), other.getFields()) + && Objects.equals(getOptions(), other.getOptions()); } /** Returns true if two schemas are equal ignoring field names and descriptions. */ @@ -350,6 +384,8 @@ public String toString() { builder.append(field); builder.append(System.lineSeparator()); } + builder.append("Options:"); + builder.append(options); return builder.toString(); }; @@ -637,24 +673,24 @@ public static FieldType of(TypeName typeName) { public static final FieldType DATETIME = FieldType.of(TypeName.DATETIME); /** Create an array type for the given field type. */ - public static final FieldType array(FieldType elementType) { + public static FieldType array(FieldType elementType) { return FieldType.forTypeName(TypeName.ARRAY).setCollectionElementType(elementType).build(); } /** @deprecated Set the nullability on the elementType instead */ @Deprecated - public static final FieldType array(FieldType elementType, boolean nullable) { + public static FieldType array(FieldType elementType, boolean nullable) { return FieldType.forTypeName(TypeName.ARRAY) .setCollectionElementType(elementType.withNullable(nullable)) .build(); } - public static final FieldType iterable(FieldType elementType) { + public static FieldType iterable(FieldType elementType) { return FieldType.forTypeName(TypeName.ITERABLE).setCollectionElementType(elementType).build(); } /** Create a map type for the given key and value types. */ - public static final FieldType map(FieldType keyType, FieldType valueType) { + public static FieldType map(FieldType keyType, FieldType valueType) { return FieldType.forTypeName(TypeName.MAP) .setMapKeyType(keyType) .setMapValueType(valueType) @@ -663,8 +699,7 @@ public static final FieldType map(FieldType keyType, FieldType valueType) { /** @deprecated Set the nullability on the valueType instead */ @Deprecated - public static final FieldType map( - FieldType keyType, FieldType valueType, boolean valueTypeNullable) { + public static FieldType map(FieldType keyType, FieldType valueType, boolean valueTypeNullable) { return FieldType.forTypeName(TypeName.MAP) .setMapKeyType(keyType) .setMapValueType(valueType.withNullable(valueTypeNullable)) @@ -672,13 +707,12 @@ public static final FieldType map( } /** Create a map type for the given key and value types. */ - public static final FieldType row(Schema schema) { + public static FieldType row(Schema schema) { return FieldType.forTypeName(TypeName.ROW).setRowSchema(schema).build(); } /** Creates a logical type based on a primitive field type. */ - public static final FieldType logicalType( - LogicalType logicalType) { + public static FieldType logicalType(LogicalType logicalType) { return FieldType.forTypeName(TypeName.LOGICAL_TYPE).setLogicalType(logicalType).build(); } @@ -874,6 +908,9 @@ public abstract static class Field implements Serializable { /** Returns the fields {@link FieldType}. */ public abstract FieldType getType(); + /** Returns the fields {@link Options}. */ + public abstract Options getOptions(); + public abstract Builder toBuilder(); /** Builder for {@link Field}. */ @@ -885,6 +922,13 @@ public abstract static class Builder { public abstract Builder setType(FieldType fieldType); + public abstract Builder setOptions(Options options); + + public Builder setOptions(Options.Builder optionsBuilder) { + setOptions(optionsBuilder.build()); + return this; + } + public abstract Field build(); } @@ -894,6 +938,7 @@ public static Field of(String name, FieldType fieldType) { .setName(name) .setDescription("") .setType(fieldType) + .setOptions(Options.none()) .build(); } @@ -903,6 +948,7 @@ public static Field nullable(String name, FieldType fieldType) { .setName(name) .setDescription("") .setType(fieldType.withNullable(true)) + .setOptions(Options.none()) .build(); } @@ -926,6 +972,16 @@ public Field withNullable(boolean isNullable) { return toBuilder().setType(getType().withNullable(isNullable)).build(); } + /** Returns a copy of the Field with the options set. */ + public Field withOptions(Options options) { + return toBuilder().setOptions(getOptions().toBuilder().addOptions(options).build()).build(); + } + + /** Returns a copy of the Field with the options set. */ + public Field withOptions(Options.Builder optionsBuilder) { + return withOptions(optionsBuilder.build()); + } + @Override public boolean equals(Object o) { if (!(o instanceof Field)) { @@ -934,7 +990,8 @@ public boolean equals(Object o) { Field other = (Field) o; return Objects.equals(getName(), other.getName()) && Objects.equals(getDescription(), other.getDescription()) - && Objects.equals(getType(), other.getType()); + && Objects.equals(getType(), other.getType()) + && Objects.equals(getOptions(), other.getOptions()); } /** Returns true if two fields are equal, ignoring name and description. */ @@ -953,6 +1010,209 @@ public int hashCode() { } } + public static class Options implements Serializable { + private final Map options; + + @Override + public String toString() { + TreeMap sorted = new TreeMap(options); + return "{" + sorted + '}'; + } + + Map getAllOptions() { + return options; + } + + public Set getOptionNames() { + return options.keySet(); + } + + public boolean hasOptions() { + return options.size() > 0; + } + + public boolean hasOption(String name) { + return options.containsKey(name); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Options options1 = (Options) o; + if (!options.keySet().equals(options1.options.keySet())) { + return false; + } + for (Map.Entry optionEntry : options.entrySet()) { + Option thisOption = optionEntry.getValue(); + Option otherOption = options1.options.get(optionEntry.getKey()); + if (!thisOption.equals(otherOption)) { + return false; + } + } + return true; + } + + @Override + public int hashCode() { + return Objects.hash(options); + } + + static class Option implements Serializable { + Option(FieldType type, Object value) { + this.type = type; + this.value = value; + } + + private FieldType type; + private Object value; + + @SuppressWarnings("TypeParameterUnusedInFormals") + T getValue() { + return (T) value; + } + + FieldType getType() { + return type; + } + + @Override + public String toString() { + return "Option{type=" + type + ", value=" + value + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Option option = (Option) o; + return Objects.equals(type, option.type) + && Row.Equals.deepEquals(value, option.value, type); + } + + @Override + public int hashCode() { + return Row.Equals.deepHashCode(value, type); + } + } + + public static class Builder { + private Map options; + + Builder(Map init) { + this.options = new HashMap<>(init); + } + + Builder() { + this(new HashMap<>()); + } + + public Builder setRowOption(String optionName, Row value) { + setOption(optionName, FieldType.row(value.getSchema()), value); + return this; + } + + public Builder setOption(String optionName, FieldType fieldType, Object value) { + if (value == null) { + if (fieldType.getNullable()) { + options.put(optionName, new Option(fieldType, null)); + } else { + throw new IllegalArgumentException( + String.format("Option %s is not nullable", optionName)); + } + } else { + options.put( + optionName, new Option(fieldType, verifyFieldValue(value, fieldType, optionName))); + } + return this; + } + + public Options build() { + return new Options(this.options); + } + + public Builder addOptions(Options options) { + this.options.putAll(options.options); + return this; + } + } + + Options(Map options) { + this.options = options; + } + + Options() { + this.options = new HashMap<>(); + } + + Options.Builder toBuilder() { + return new Builder(new HashMap<>(this.options)); + } + + public static Options.Builder builder() { + return new Builder(); + } + + public static Options none() { + return new Options(); + } + + /** Get the value of an option. If the option is not found null is returned. */ + @SuppressWarnings("TypeParameterUnusedInFormals") + @Nullable + public T getValue(String optionName) { + Option option = options.get(optionName); + if (option != null) { + return option.getValue(); + } + throw new IllegalArgumentException( + String.format("No option found with name %s.", optionName)); + } + + /** Get the value of an option. If the option is not found null is returned. */ + @Nullable + public T getValue(String optionName, Class valueClass) { + return getValue(optionName); + } + + /** Get the value of an option. If the option is not found the default value is returned. */ + @Nullable + public T getValueOrDefault(String optionName, T defaultValue) { + Option option = options.get(optionName); + if (option != null) { + return option.getValue(); + } + return defaultValue; + } + + /** Get the type of an option. */ + @Nullable + public FieldType getType(String optionName) { + Option option = options.get(optionName); + if (option != null) { + return option.getType(); + } + throw new IllegalArgumentException( + String.format("No option found with name %s.", optionName)); + } + + public static Options.Builder setOption(String optionName, FieldType fieldType, Object value) { + return Options.builder().setOption(optionName, fieldType, value); + } + + public static Options.Builder setRowOption(String optionName, Row value) { + return Options.builder().setRowOption(optionName, value); + } + } + /** Collects a stream of {@link Field}s into a {@link Schema}. */ public static Collector, Schema> toSchema() { return Collector.of( @@ -986,10 +1246,8 @@ public Field getField(String name) { /** Find the index of a given field. */ public int indexOf(String fieldName) { Integer index = fieldIndices.get(fieldName); - if (index == null) { - throw new IllegalArgumentException( - String.format("Cannot find field %s in schema %s", fieldName, this)); - } + Preconditions.checkArgument( + index != null, String.format("Cannot find field %s in schema %s", fieldName, this)); return index; } @@ -1001,9 +1259,7 @@ public boolean hasField(String fieldName) { /** Return the name of field by index. */ public String nameOf(int fieldIndex) { String name = fieldIndices.inverse().get(fieldIndex); - if (name == null) { - throw new IllegalArgumentException(String.format("Cannot find field %d", fieldIndex)); - } + Preconditions.checkArgument(name != null, String.format("Cannot find field %d", fieldIndex)); return name; } @@ -1011,4 +1267,8 @@ public String nameOf(int fieldIndex) { public int getFieldCount() { return getFields().size(); } + + public Options getOptions() { + return this.options; + } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaTranslation.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaTranslation.java index 0d19b65d2af0..d311643b7e62 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaTranslation.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaTranslation.java @@ -17,10 +17,15 @@ */ package org.apache.beam.sdk.schemas; +import java.util.ArrayList; +import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.UUID; +import java.util.stream.Collectors; import org.apache.beam.model.pipeline.v1.SchemaApi; import org.apache.beam.model.pipeline.v1.SchemaApi.ArrayTypeValue; +import org.apache.beam.model.pipeline.v1.SchemaApi.AtomicTypeValue; import org.apache.beam.model.pipeline.v1.SchemaApi.FieldValue; import org.apache.beam.model.pipeline.v1.SchemaApi.IterableTypeValue; import org.apache.beam.model.pipeline.v1.SchemaApi.MapTypeEntry; @@ -57,6 +62,7 @@ public static SchemaApi.Schema schemaToProto(Schema schema, boolean serializeLog serializeLogicalType); builder.addFields(protoField); } + builder.addAllOptions(optionsToProto(schema.getOptions())); return builder.build(); } @@ -68,6 +74,7 @@ private static SchemaApi.Field fieldToProto( .setType(fieldTypeToProto(field.getType(), serializeLogicalType)) .setId(fieldId) .setEncodingPosition(position) + .addAllOptions(optionsToProto(field.getOptions())) .build(); } @@ -110,7 +117,7 @@ private static SchemaApi.FieldType fieldTypeToProto( .setArgumentType( fieldTypeToProto(logicalType.getArgumentType(), serializeLogicalType)) .setArgument( - rowFieldToProto(logicalType.getArgumentType(), logicalType.getArgument())) + fieldValueToProto(logicalType.getArgumentType(), logicalType.getArgument())) .setRepresentation( fieldTypeToProto(logicalType.getBaseType(), serializeLogicalType)) // TODO(BEAM-7855): "javasdk" types should only be a last resort. Types defined in @@ -172,7 +179,7 @@ private static SchemaApi.FieldType fieldTypeToProto( return builder.build(); } - public static Schema fromProto(SchemaApi.Schema protoSchema) { + public static Schema schemaFromProto(SchemaApi.Schema protoSchema) { Schema.Builder builder = Schema.builder(); Map encodingLocationMap = Maps.newHashMap(); for (SchemaApi.Field protoField : protoSchema.getFieldsList()) { @@ -180,17 +187,18 @@ public static Schema fromProto(SchemaApi.Schema protoSchema) { builder.addField(field); encodingLocationMap.put(protoField.getName(), protoField.getEncodingPosition()); } + builder.setOptions(optionsFromProto(protoSchema.getOptionsList())); Schema schema = builder.build(); schema.setEncodingPositions(encodingLocationMap); if (!protoSchema.getId().isEmpty()) { schema.setUUID(UUID.fromString(protoSchema.getId())); } - return schema; } private static Field fieldFromProto(SchemaApi.Field protoField) { return Field.of(protoField.getName(), fieldTypeFromProto(protoField.getType())) + .withOptions(optionsFromProto(protoField.getOptionsList())) .withDescription(protoField.getDescription()); } @@ -233,7 +241,7 @@ private static FieldType fieldTypeFromProtoWithoutNullable(SchemaApi.FieldType p "Encountered unknown AtomicType: " + protoFieldType.getAtomicType()); } case ROW_TYPE: - return FieldType.row(fromProto(protoFieldType.getRowType().getSchema())); + return FieldType.row(schemaFromProto(protoFieldType.getRowType().getSchema())); case ARRAY_TYPE: return FieldType.array(fieldTypeFromProto(protoFieldType.getArrayType().getElementType())); case ITERABLE_TYPE: @@ -268,12 +276,21 @@ private static FieldType fieldTypeFromProtoWithoutNullable(SchemaApi.FieldType p public static SchemaApi.Row rowToProto(Row row) { SchemaApi.Row.Builder builder = SchemaApi.Row.newBuilder(); for (int i = 0; i < row.getFieldCount(); ++i) { - builder.addValues(rowFieldToProto(row.getSchema().getField(i).getType(), row.getValue(i))); + builder.addValues(fieldValueToProto(row.getSchema().getField(i).getType(), row.getValue(i))); + } + return builder.build(); + } + + public static Object rowFromProto(SchemaApi.Row row, FieldType fieldType) { + Row.Builder builder = Row.withSchema(fieldType.getRowSchema()); + for (int i = 0; i < row.getValuesCount(); ++i) { + builder.addValue( + fieldValueFromProto(fieldType.getRowSchema().getField(i).getType(), row.getValues(i))); } return builder.build(); } - private static SchemaApi.FieldValue rowFieldToProto(FieldType fieldType, Object value) { + static SchemaApi.FieldValue fieldValueToProto(FieldType fieldType, Object value) { FieldValue.Builder builder = FieldValue.newBuilder(); switch (fieldType.getTypeName()) { case ARRAY: @@ -299,59 +316,146 @@ private static SchemaApi.FieldValue rowFieldToProto(FieldType fieldType, Object } } + static Object fieldValueFromProto(FieldType fieldType, SchemaApi.FieldValue value) { + switch (fieldType.getTypeName()) { + case ARRAY: + return arrayValueFromProto(fieldType.getCollectionElementType(), value.getArrayValue()); + case ITERABLE: + return iterableValueFromProto( + fieldType.getCollectionElementType(), value.getIterableValue()); + case MAP: + return mapFromProto( + fieldType.getMapKeyType(), fieldType.getMapValueType(), value.getMapValue()); + case ROW: + return rowFromProto(value.getRowValue(), fieldType); + case LOGICAL_TYPE: + default: + return primitiveFromProto(fieldType, value.getAtomicValue()); + } + } + private static SchemaApi.ArrayTypeValue arrayValueToProto( FieldType elementType, Iterable values) { return ArrayTypeValue.newBuilder() - .addAllElement(Iterables.transform(values, e -> rowFieldToProto(elementType, e))) + .addAllElement(Iterables.transform(values, e -> fieldValueToProto(elementType, e))) .build(); } + private static Iterable arrayValueFromProto( + FieldType elementType, SchemaApi.ArrayTypeValue values) { + return values.getElementList().stream() + .map(e -> fieldValueFromProto(elementType, e)) + .collect(Collectors.toList()); + } + private static SchemaApi.IterableTypeValue iterableValueToProto( FieldType elementType, Iterable values) { return IterableTypeValue.newBuilder() - .addAllElement(Iterables.transform(values, e -> rowFieldToProto(elementType, e))) + .addAllElement(Iterables.transform(values, e -> fieldValueToProto(elementType, e))) .build(); } + private static Object iterableValueFromProto(FieldType elementType, IterableTypeValue values) { + return values.getElementList().stream() + .map(e -> fieldValueFromProto(elementType, e)) + .collect(Collectors.toList()); + } + private static SchemaApi.MapTypeValue mapToProto( FieldType keyType, FieldType valueType, Map map) { MapTypeValue.Builder builder = MapTypeValue.newBuilder(); for (Map.Entry entry : map.entrySet()) { MapTypeEntry mapProtoEntry = MapTypeEntry.newBuilder() - .setKey(rowFieldToProto(keyType, entry.getKey())) - .setValue(rowFieldToProto(valueType, entry.getValue())) + .setKey(fieldValueToProto(keyType, entry.getKey())) + .setValue(fieldValueToProto(valueType, entry.getValue())) .build(); builder.addEntries(mapProtoEntry); } return builder.build(); } - private static SchemaApi.AtomicTypeValue primitiveRowFieldToProto( - FieldType fieldType, Object value) { + private static Object mapFromProto( + FieldType mapKeyType, FieldType mapValueType, MapTypeValue mapValue) { + return mapValue.getEntriesList().stream() + .collect( + Collectors.toMap( + entry -> fieldValueFromProto(mapKeyType, entry.getKey()), + entry -> fieldValueFromProto(mapValueType, entry.getValue()))); + } + + private static AtomicTypeValue primitiveRowFieldToProto(FieldType fieldType, Object value) { switch (fieldType.getTypeName()) { case BYTE: - return SchemaApi.AtomicTypeValue.newBuilder().setByte((int) value).build(); + return AtomicTypeValue.newBuilder().setByte((byte) value).build(); case INT16: - return SchemaApi.AtomicTypeValue.newBuilder().setInt16((int) value).build(); + return AtomicTypeValue.newBuilder().setInt16((short) value).build(); case INT32: - return SchemaApi.AtomicTypeValue.newBuilder().setInt32((int) value).build(); + return AtomicTypeValue.newBuilder().setInt32((int) value).build(); case INT64: - return SchemaApi.AtomicTypeValue.newBuilder().setInt64((long) value).build(); + return AtomicTypeValue.newBuilder().setInt64((long) value).build(); case FLOAT: - return SchemaApi.AtomicTypeValue.newBuilder().setFloat((float) value).build(); + return AtomicTypeValue.newBuilder().setFloat((float) value).build(); case DOUBLE: - return SchemaApi.AtomicTypeValue.newBuilder().setDouble((double) value).build(); + return AtomicTypeValue.newBuilder().setDouble((double) value).build(); case STRING: - return SchemaApi.AtomicTypeValue.newBuilder().setString((String) value).build(); + return AtomicTypeValue.newBuilder().setString((String) value).build(); case BOOLEAN: - return SchemaApi.AtomicTypeValue.newBuilder().setBoolean((boolean) value).build(); + return AtomicTypeValue.newBuilder().setBoolean((boolean) value).build(); case BYTES: - return SchemaApi.AtomicTypeValue.newBuilder() - .setBytes(ByteString.copyFrom((byte[]) value)) - .build(); + return AtomicTypeValue.newBuilder().setBytes(ByteString.copyFrom((byte[]) value)).build(); default: throw new RuntimeException("FieldType unexpected " + fieldType.getTypeName()); } } + + private static Object primitiveFromProto(FieldType fieldType, AtomicTypeValue value) { + switch (fieldType.getTypeName()) { + case BYTE: + return (byte) value.getByte(); + case INT16: + return (short) value.getInt16(); + case INT32: + return value.getInt32(); + case INT64: + return value.getInt64(); + case FLOAT: + return value.getFloat(); + case DOUBLE: + return value.getDouble(); + case STRING: + return value.getString(); + case BOOLEAN: + return value.getBoolean(); + case BYTES: + return value.getBytes().toByteArray(); + default: + throw new RuntimeException("FieldType unexpected " + fieldType.getTypeName()); + } + } + + private static List optionsToProto(Schema.Options options) { + List protoOptions = new ArrayList<>(); + for (String name : options.getOptionNames()) { + protoOptions.add( + SchemaApi.Option.newBuilder() + .setName(name) + .setType(fieldTypeToProto(Objects.requireNonNull(options.getType(name)), false)) + .setValue( + fieldValueToProto( + Objects.requireNonNull(options.getType(name)), options.getValue(name))) + .build()); + } + return protoOptions; + } + + private static Schema.Options optionsFromProto(List protoOptions) { + Schema.Options.Builder optionBuilder = Schema.Options.builder(); + for (SchemaApi.Option protoOption : protoOptions) { + FieldType fieldType = fieldTypeFromProto(protoOption.getType()); + optionBuilder.setOption( + protoOption.getName(), fieldType, fieldValueFromProto(fieldType, protoOption.getValue())); + } + return optionBuilder.build(); + } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java index 9c35657c7ce4..8f4cf6243874 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java @@ -17,12 +17,12 @@ */ package org.apache.beam.sdk.values; +import static org.apache.beam.sdk.values.SchemaVerification.verifyRowValues; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; import java.io.Serializable; import java.math.BigDecimal; -import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -30,7 +30,6 @@ import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.Map.Entry; import java.util.Objects; import java.util.stream.Collector; import javax.annotation.Nullable; @@ -39,17 +38,13 @@ import org.apache.beam.sdk.schemas.Factory; import org.apache.beam.sdk.schemas.FieldValueGetter; import org.apache.beam.sdk.schemas.Schema; -import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.Schema.LogicalType; import org.apache.beam.sdk.schemas.Schema.TypeName; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; import org.joda.time.DateTime; -import org.joda.time.Instant; import org.joda.time.ReadableDateTime; import org.joda.time.ReadableInstant; -import org.joda.time.base.AbstractInstant; /** * {@link Row} is an immutable tuple-like schema to represent one element in a {@link PCollection}. @@ -434,7 +429,7 @@ public static boolean deepEquals(Object a, Object b, Schema.FieldType fieldType) } } - static int deepHashCode(Object a, Schema.FieldType fieldType) { + public static int deepHashCode(Object a, Schema.FieldType fieldType) { if (a == null) { return 0; } else if (fieldType.getTypeName() == TypeName.LOGICAL_TYPE) { @@ -614,230 +609,13 @@ public Builder withFieldValueGetters( return this; } - private List verify(Schema schema, List values) { - List verifiedValues = Lists.newArrayListWithCapacity(values.size()); - if (schema.getFieldCount() != values.size()) { - throw new IllegalArgumentException( - String.format( - "Field count in Schema (%s) (%d) and values (%s) (%d) must match", - schema.getFieldNames(), schema.getFieldCount(), values, values.size())); - } - for (int i = 0; i < values.size(); ++i) { - Object value = values.get(i); - Schema.Field field = schema.getField(i); - if (value == null) { - if (!field.getType().getNullable()) { - throw new IllegalArgumentException( - String.format("Field %s is not nullable", field.getName())); - } - verifiedValues.add(null); - } else { - verifiedValues.add(verify(value, field.getType(), field.getName())); - } - } - return verifiedValues; - } - - private Object verify(Object value, FieldType type, String fieldName) { - if (TypeName.ARRAY.equals(type.getTypeName())) { - return verifyArray(value, type.getCollectionElementType(), fieldName); - } else if (TypeName.ITERABLE.equals(type.getTypeName())) { - return verifyIterable(value, type.getCollectionElementType(), fieldName); - } - if (TypeName.MAP.equals(type.getTypeName())) { - return verifyMap(value, type.getMapKeyType(), type.getMapValueType(), fieldName); - } else if (TypeName.ROW.equals(type.getTypeName())) { - return verifyRow(value, fieldName); - } else if (TypeName.LOGICAL_TYPE.equals(type.getTypeName())) { - return verifyLogicalType(value, type.getLogicalType(), fieldName); - } else { - return verifyPrimitiveType(value, type.getTypeName(), fieldName); - } - } - - private Object verifyLogicalType(Object value, LogicalType logicalType, String fieldName) { - return verify(logicalType.toBaseType(value), logicalType.getBaseType(), fieldName); - } - - private List verifyArray( - Object value, FieldType collectionElementType, String fieldName) { - boolean collectionElementTypeNullable = collectionElementType.getNullable(); - if (!(value instanceof Collection)) { - throw new IllegalArgumentException( - String.format( - "For field name %s and array type expected Collection class. Instead " - + "class type was %s.", - fieldName, value.getClass())); - } - Collection valueCollection = (Collection) value; - List verifiedList = Lists.newArrayListWithCapacity(valueCollection.size()); - for (Object listValue : valueCollection) { - if (listValue == null) { - if (!collectionElementTypeNullable) { - throw new IllegalArgumentException( - String.format( - "%s is not nullable in Array field %s", collectionElementType, fieldName)); - } - verifiedList.add(null); - } else { - verifiedList.add(verify(listValue, collectionElementType, fieldName)); - } - } - return verifiedList; - } - - private Iterable verifyIterable( - Object value, FieldType collectionElementType, String fieldName) { - boolean collectionElementTypeNullable = collectionElementType.getNullable(); - if (!(value instanceof Iterable)) { - throw new IllegalArgumentException( - String.format( - "For field name %s and iterable type expected class extending Iterable. Instead " - + "class type was %s.", - fieldName, value.getClass())); - } - Iterable valueIterable = (Iterable) value; - for (Object listValue : valueIterable) { - if (listValue == null) { - if (!collectionElementTypeNullable) { - throw new IllegalArgumentException( - String.format( - "%s is not nullable in Array field %s", collectionElementType, fieldName)); - } - } else { - verify(listValue, collectionElementType, fieldName); - } - } - return valueIterable; - } - - private Map verifyMap( - Object value, FieldType keyType, FieldType valueType, String fieldName) { - boolean valueTypeNullable = valueType.getNullable(); - if (!(value instanceof Map)) { - throw new IllegalArgumentException( - String.format( - "For field name %s and map type expected Map class. Instead " - + "class type was %s.", - fieldName, value.getClass())); - } - Map valueMap = (Map) value; - Map verifiedMap = Maps.newHashMapWithExpectedSize(valueMap.size()); - for (Entry kv : valueMap.entrySet()) { - if (kv.getValue() == null) { - if (!valueTypeNullable) { - throw new IllegalArgumentException( - String.format("%s is not nullable in Map field %s", valueType, fieldName)); - } - verifiedMap.put(verify(kv.getKey(), keyType, fieldName), null); - } else { - verifiedMap.put( - verify(kv.getKey(), keyType, fieldName), verify(kv.getValue(), valueType, fieldName)); - } - } - return verifiedMap; - } - - private Row verifyRow(Object value, String fieldName) { - if (!(value instanceof Row)) { - throw new IllegalArgumentException( - String.format( - "For field name %s expected Row type. " + "Instead class type was %s.", - fieldName, value.getClass())); - } - // No need to recursively validate the nested Row, since there's no way to build the - // Row object without it validating. - return (Row) value; - } - - private Object verifyPrimitiveType(Object value, TypeName type, String fieldName) { - if (type.isDateType()) { - return verifyDateTime(value, fieldName); - } else { - switch (type) { - case BYTE: - if (value instanceof Byte) { - return value; - } - break; - case BYTES: - if (value instanceof ByteBuffer) { - return ((ByteBuffer) value).array(); - } else if (value instanceof byte[]) { - return (byte[]) value; - } - break; - case INT16: - if (value instanceof Short) { - return value; - } - break; - case INT32: - if (value instanceof Integer) { - return value; - } - break; - case INT64: - if (value instanceof Long) { - return value; - } - break; - case DECIMAL: - if (value instanceof BigDecimal) { - return value; - } - break; - case FLOAT: - if (value instanceof Float) { - return value; - } - break; - case DOUBLE: - if (value instanceof Double) { - return value; - } - break; - case STRING: - if (value instanceof String) { - return value; - } - break; - case BOOLEAN: - if (value instanceof Boolean) { - return value; - } - break; - default: - // Shouldn't actually get here, but we need this case to satisfy linters. - throw new IllegalArgumentException( - String.format("Not a primitive type for field name %s: %s", fieldName, type)); - } - throw new IllegalArgumentException( - String.format( - "For field name %s and type %s found incorrect class type %s", - fieldName, type, value.getClass())); - } - } - - private Instant verifyDateTime(Object value, String fieldName) { - // We support the following classes for datetimes. - if (value instanceof AbstractInstant) { - return ((AbstractInstant) value).toInstant(); - } else { - throw new IllegalArgumentException( - String.format( - "For field name %s and DATETIME type got unexpected class %s ", - fieldName, value.getClass())); - } - } - public Row build() { checkNotNull(schema); if (!this.values.isEmpty() && fieldValueGetterFactory != null) { throw new IllegalArgumentException("Cannot specify both values and getters."); } if (!this.values.isEmpty()) { - List storageValues = attached ? this.values : verify(schema, this.values); + List storageValues = attached ? this.values : verifyRowValues(schema, this.values); checkState(getterTarget == null, "withGetterTarget requires getters."); return new RowWithStorage(schema, storageValues); } else if (fieldValueGetterFactory != null) { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/SchemaVerification.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/SchemaVerification.java new file mode 100644 index 000000000000..e06e7028bcb2 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/SchemaVerification.java @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.values; + +import java.io.Serializable; +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.Schema.LogicalType; +import org.apache.beam.sdk.schemas.Schema.TypeName; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; +import org.joda.time.Instant; +import org.joda.time.base.AbstractInstant; + +@Experimental +public abstract class SchemaVerification implements Serializable { + + static List verifyRowValues(Schema schema, List values) { + List verifiedValues = Lists.newArrayListWithCapacity(values.size()); + if (schema.getFieldCount() != values.size()) { + throw new IllegalArgumentException( + String.format( + "Field count in Schema (%s) (%d) and values (%s) (%d) must match", + schema.getFieldNames(), schema.getFieldCount(), values, values.size())); + } + for (int i = 0; i < values.size(); ++i) { + Object value = values.get(i); + Schema.Field field = schema.getField(i); + if (value == null) { + if (!field.getType().getNullable()) { + throw new IllegalArgumentException( + String.format("Field %s is not nullable", field.getName())); + } + verifiedValues.add(null); + } else { + verifiedValues.add(verifyFieldValue(value, field.getType(), field.getName())); + } + } + return verifiedValues; + } + + public static Object verifyFieldValue(Object value, FieldType type, String fieldName) { + if (TypeName.ARRAY.equals(type.getTypeName())) { + return verifyArray(value, type.getCollectionElementType(), fieldName); + } else if (TypeName.ITERABLE.equals(type.getTypeName())) { + return verifyIterable(value, type.getCollectionElementType(), fieldName); + } + if (TypeName.MAP.equals(type.getTypeName())) { + return verifyMap(value, type.getMapKeyType(), type.getMapValueType(), fieldName); + } else if (TypeName.ROW.equals(type.getTypeName())) { + return verifyRow(value, fieldName); + } else if (TypeName.LOGICAL_TYPE.equals(type.getTypeName())) { + return verifyLogicalType(value, type.getLogicalType(), fieldName); + } else { + return verifyPrimitiveType(value, type.getTypeName(), fieldName); + } + } + + private static Object verifyLogicalType(Object value, LogicalType logicalType, String fieldName) { + return verifyFieldValue(logicalType.toBaseType(value), logicalType.getBaseType(), fieldName); + } + + private static List verifyArray( + Object value, FieldType collectionElementType, String fieldName) { + boolean collectionElementTypeNullable = collectionElementType.getNullable(); + if (!(value instanceof List)) { + throw new IllegalArgumentException( + String.format( + "For field name %s and array type expected List class. Instead " + + "class type was %s.", + fieldName, value.getClass())); + } + List valueList = (List) value; + List verifiedList = Lists.newArrayListWithCapacity(valueList.size()); + for (Object listValue : valueList) { + if (listValue == null) { + if (!collectionElementTypeNullable) { + throw new IllegalArgumentException( + String.format( + "%s is not nullable in Array field %s", collectionElementType, fieldName)); + } + verifiedList.add(null); + } else { + verifiedList.add(verifyFieldValue(listValue, collectionElementType, fieldName)); + } + } + return verifiedList; + } + + private static Iterable verifyIterable( + Object value, FieldType collectionElementType, String fieldName) { + boolean collectionElementTypeNullable = collectionElementType.getNullable(); + if (!(value instanceof Iterable)) { + throw new IllegalArgumentException( + String.format( + "For field name %s and iterable type expected class extending Iterable. Instead " + + "class type was %s.", + fieldName, value.getClass())); + } + Iterable valueList = (Iterable) value; + for (Object listValue : valueList) { + if (listValue == null) { + if (!collectionElementTypeNullable) { + throw new IllegalArgumentException( + String.format( + "%s is not nullable in Array field %s", collectionElementType, fieldName)); + } + } else { + verifyFieldValue(listValue, collectionElementType, fieldName); + } + } + return valueList; + } + + private static Map verifyMap( + Object value, FieldType keyType, FieldType valueType, String fieldName) { + boolean valueTypeNullable = valueType.getNullable(); + if (!(value instanceof Map)) { + throw new IllegalArgumentException( + String.format( + "For field name %s and map type expected Map class. Instead " + "class type was %s.", + fieldName, value.getClass())); + } + Map valueMap = (Map) value; + Map verifiedMap = Maps.newHashMapWithExpectedSize(valueMap.size()); + for (Entry kv : valueMap.entrySet()) { + if (kv.getValue() == null) { + if (!valueTypeNullable) { + throw new IllegalArgumentException( + String.format("%s is not nullable in Map field %s", valueType, fieldName)); + } + verifiedMap.put(verifyFieldValue(kv.getKey(), keyType, fieldName), null); + } else { + verifiedMap.put( + verifyFieldValue(kv.getKey(), keyType, fieldName), + verifyFieldValue(kv.getValue(), valueType, fieldName)); + } + } + return verifiedMap; + } + + private static Row verifyRow(Object value, String fieldName) { + if (!(value instanceof Row)) { + throw new IllegalArgumentException( + String.format( + "For field name %s expected Row type. " + "Instead class type was %s.", + fieldName, value.getClass())); + } + // No need to recursively validate the nested Row, since there's no way to build the + // Row object without it validating. + return (Row) value; + } + + private static Object verifyPrimitiveType(Object value, TypeName type, String fieldName) { + if (type.isDateType()) { + return verifyDateTime(value, fieldName); + } else { + switch (type) { + case BYTE: + if (value instanceof Byte) { + return value; + } + break; + case BYTES: + if (value instanceof ByteBuffer) { + return ((ByteBuffer) value).array(); + } else if (value instanceof byte[]) { + return (byte[]) value; + } + break; + case INT16: + if (value instanceof Short) { + return value; + } + break; + case INT32: + if (value instanceof Integer) { + return value; + } + break; + case INT64: + if (value instanceof Long) { + return value; + } + break; + case DECIMAL: + if (value instanceof BigDecimal) { + return value; + } + break; + case FLOAT: + if (value instanceof Float) { + return value; + } + break; + case DOUBLE: + if (value instanceof Double) { + return value; + } + break; + case STRING: + if (value instanceof String) { + return value; + } + break; + case BOOLEAN: + if (value instanceof Boolean) { + return value; + } + break; + default: + // Shouldn't actually get here, but we need this case to satisfy linters. + throw new IllegalArgumentException( + String.format("Not a primitive type for field name %s: %s", fieldName, type)); + } + throw new IllegalArgumentException( + String.format( + "For field name %s and type %s found incorrect class type %s", + fieldName, type, value.getClass())); + } + } + + private static Instant verifyDateTime(Object value, String fieldName) { + // We support the following classes for datetimes. + if (value instanceof AbstractInstant) { + return ((AbstractInstant) value).toInstant(); + } else { + throw new IllegalArgumentException( + String.format( + "For field name %s and DATETIME type got unexpected class %s ", + fieldName, value.getClass())); + } + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/SchemaOptionsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/SchemaOptionsTest.java new file mode 100644 index 000000000000..f3e3685f98c8 --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/SchemaOptionsTest.java @@ -0,0 +1,347 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.schemas; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.apache.beam.sdk.schemas.Schema.Field; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.values.Row; +import org.joda.time.DateTime; +import org.joda.time.DateTimeZone; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +/** Unit tests for {@link Schema.Options}. */ +public class SchemaOptionsTest { + + private static final String OPTION_NAME = "beam:test:field_i1"; + private static final String FIELD_NAME = "f_field"; + private static final Field FIELD = Field.of(FIELD_NAME, FieldType.STRING); + private static final Row TEST_ROW = + Row.withSchema( + Schema.builder() + .addField("field_one", FieldType.STRING) + .addField("field_two", FieldType.INT32) + .build()) + .addValue("value") + .addValue(42) + .build(); + + private static final Map TEST_MAP = new HashMap<>(); + + static { + TEST_MAP.put(1, "one"); + TEST_MAP.put(2, "two"); + } + + private static final List TEST_LIST = new ArrayList<>(); + + static { + TEST_LIST.add("one"); + TEST_LIST.add("two"); + TEST_LIST.add("three"); + } + + private static final byte[] TEST_BYTES = new byte[] {0x42, 0x69, 0x00}; + + @Rule public ExpectedException thrown = ExpectedException.none(); + + @Test + public void testBooleanOption() { + Schema.Options options = Schema.Options.setOption(OPTION_NAME, FieldType.BOOLEAN, true).build(); + assertEquals(true, options.getValue(OPTION_NAME)); + assertEquals(FieldType.BOOLEAN, options.getType(OPTION_NAME)); + } + + @Test + public void testInt16Option() { + Schema.Options options = + Schema.Options.setOption(OPTION_NAME, FieldType.INT16, (short) 12).build(); + assertEquals(Short.valueOf((short) 12), options.getValue(OPTION_NAME)); + assertEquals(FieldType.INT16, options.getType(OPTION_NAME)); + } + + @Test + public void testByteOption() { + Schema.Options options = + Schema.Options.setOption(OPTION_NAME, FieldType.BYTE, (byte) 12).build(); + assertEquals(Byte.valueOf((byte) 12), options.getValue(OPTION_NAME)); + assertEquals(FieldType.BYTE, options.getType(OPTION_NAME)); + } + + @Test + public void testInt32Option() { + Schema.Options options = Schema.Options.setOption(OPTION_NAME, FieldType.INT32, 12).build(); + assertEquals(Integer.valueOf(12), options.getValue(OPTION_NAME)); + assertEquals(FieldType.INT32, options.getType(OPTION_NAME)); + } + + @Test + public void testInt64Option() { + Schema.Options options = Schema.Options.setOption(OPTION_NAME, FieldType.INT64, 12L).build(); + assertEquals(Long.valueOf(12), options.getValue(OPTION_NAME)); + assertEquals(FieldType.INT64, options.getType(OPTION_NAME)); + } + + @Test + public void testFloatOption() { + Schema.Options options = + Schema.Options.setOption(OPTION_NAME, FieldType.FLOAT, (float) 12.0).build(); + assertEquals(Float.valueOf((float) 12.0), options.getValue(OPTION_NAME)); + assertEquals(FieldType.FLOAT, options.getType(OPTION_NAME)); + } + + @Test + public void testDoubleOption() { + Schema.Options options = Schema.Options.setOption(OPTION_NAME, FieldType.DOUBLE, 12.0).build(); + assertEquals(Double.valueOf(12.0), options.getValue(OPTION_NAME)); + assertEquals(FieldType.DOUBLE, options.getType(OPTION_NAME)); + } + + @Test + public void testStringOption() { + Schema.Options options = Schema.Options.setOption(OPTION_NAME, FieldType.STRING, "foo").build(); + assertEquals("foo", options.getValue(OPTION_NAME)); + assertEquals(FieldType.STRING, options.getType(OPTION_NAME)); + } + + @Test + public void testBytesOption() { + byte[] bytes = new byte[] {0x42, 0x69, 0x00}; + Schema.Options options = Schema.Options.setOption(OPTION_NAME, FieldType.BYTES, bytes).build(); + assertEquals(bytes, options.getValue(OPTION_NAME)); + assertEquals(FieldType.BYTES, options.getType(OPTION_NAME)); + } + + @Test + public void testDateTimeOption() { + DateTime now = DateTime.now().withZone(DateTimeZone.UTC); + Schema.Options options = Schema.Options.setOption(OPTION_NAME, FieldType.DATETIME, now).build(); + assertEquals(now, options.getValue(OPTION_NAME)); + assertEquals(FieldType.DATETIME, options.getType(OPTION_NAME)); + } + + @Test + public void testArrayOfIntegerOption() { + Schema.Options options = + Schema.Options.setOption(OPTION_NAME, FieldType.array(FieldType.STRING), TEST_LIST).build(); + assertEquals(TEST_LIST, options.getValue(OPTION_NAME)); + assertEquals(FieldType.array(FieldType.STRING), options.getType(OPTION_NAME)); + } + + @Test + public void testMapOfIntegerStringOption() { + Schema.Options options = + Schema.Options.setOption( + OPTION_NAME, FieldType.map(FieldType.INT32, FieldType.STRING), TEST_MAP) + .build(); + assertEquals(TEST_MAP, options.getValue(OPTION_NAME)); + assertEquals(FieldType.map(FieldType.INT32, FieldType.STRING), options.getType(OPTION_NAME)); + } + + @Test + public void testRowOption() { + Schema.Options options = Schema.Options.setRowOption(OPTION_NAME, TEST_ROW).build(); + assertEquals(TEST_ROW, options.getValue(OPTION_NAME)); + assertEquals(FieldType.row(TEST_ROW.getSchema()), options.getType(OPTION_NAME)); + } + + @Test(expected = IllegalArgumentException.class) + public void testNotNullableOptionSetNull() { + Schema.Options options = Schema.Options.setOption(OPTION_NAME, FieldType.STRING, null).build(); + } + + @Test + public void testNullableOptionSetNull() { + Schema.Options options = + Schema.Options.setOption(OPTION_NAME, FieldType.STRING.withNullable(true), null).build(); + assertNull(options.getValue(OPTION_NAME)); + assertEquals(FieldType.STRING.withNullable(true), options.getType(OPTION_NAME)); + } + + @Test(expected = IllegalArgumentException.class) + public void testGetValueNoOption() { + Schema.Options options = Schema.Options.none(); + options.getValue("foo"); + } + + @Test(expected = IllegalArgumentException.class) + public void testGetTypeNoOption() { + Schema.Options options = Schema.Options.none(); + options.getType("foo"); + } + + @Test + public void testGetValueOrDefault() { + Schema.Options options = Schema.Options.none(); + assertNull(options.getValueOrDefault("foo", null)); + } + + private Schema.Options.Builder setOptionsSet1() { + return setOptionsSet1(Schema.Options.builder()); + } + + private Schema.Options.Builder setOptionsSet1(Schema.Options.Builder builder) { + return builder + .setOption("field_option_boolean", FieldType.BOOLEAN, true) + .setOption("field_option_byte", FieldType.BYTE, (byte) 12) + .setOption("field_option_int16", FieldType.INT16, (short) 12) + .setOption("field_option_int32", FieldType.INT32, 12) + .setOption("field_option_int64", FieldType.INT64, 12L) + .setOption("field_option_string", FieldType.STRING, "foo"); + } + + private Schema.Options.Builder setOptionsSet2() { + return setOptionsSet2(Schema.Options.builder()); + } + + private void assertOptionSet1(Schema.Options options) { + assertEquals(true, options.getValue("field_option_boolean")); + assertEquals(12, (byte) options.getValue("field_option_byte")); + assertEquals(12, (short) options.getValue("field_option_int16")); + assertEquals(12, (int) options.getValue("field_option_int32")); + assertEquals(12L, (long) options.getValue("field_option_int64")); + assertEquals("foo", options.getValue("field_option_string")); + } + + private Schema.Options.Builder setOptionsSet2(Schema.Options.Builder builder) { + return builder + .setOption("field_option_bytes", FieldType.BYTES, new byte[] {0x42, 0x69, 0x00}) + .setOption("field_option_float", FieldType.FLOAT, (float) 12.0) + .setOption("field_option_double", FieldType.DOUBLE, 12.0) + .setOption("field_option_map", FieldType.map(FieldType.INT32, FieldType.STRING), TEST_MAP) + .setOption("field_option_array", FieldType.array(FieldType.STRING), TEST_LIST) + .setRowOption("field_option_row", TEST_ROW) + .setOption("field_option_value", FieldType.STRING, "other"); + } + + private void assertOptionSet2(Schema.Options options) { + assertArrayEquals(TEST_BYTES, options.getValue("field_option_bytes")); + assertEquals((float) 12.0, (float) options.getValue("field_option_float"), 0.1); + assertEquals(12.0, (double) options.getValue("field_option_double"), 0.1); + assertEquals(TEST_MAP, options.getValue("field_option_map")); + assertEquals(TEST_LIST, options.getValue("field_option_array")); + assertEquals(TEST_ROW, options.getValue("field_option_row")); + assertEquals("other", options.getValue("field_option_value")); + } + + @Test + public void testSchemaSetOptionWithBuilder() { + Schema schema = + Schema.builder() + .setOptions(setOptionsSet1(Schema.Options.builder())) + .addField(FIELD) + .build(); + assertOptionSet1(schema.getOptions()); + } + + @Test + public void testSchemaSetOption() { + Schema schema = + Schema.builder() + .setOptions(setOptionsSet1(Schema.Options.builder()).build()) + .addField(FIELD) + .build(); + assertOptionSet1(schema.getOptions()); + } + + @Test + public void testFieldWithOptionsBuilder() { + Schema schema = Schema.builder().addField(FIELD.withOptions(setOptionsSet1())).build(); + assertOptionSet1(schema.getField(FIELD_NAME).getOptions()); + } + + @Test + public void testFieldWithOptions() { + Schema schema = Schema.builder().addField(FIELD.withOptions(setOptionsSet1().build())).build(); + assertOptionSet1(schema.getField(FIELD_NAME).getOptions()); + } + + @Test + public void testFieldHasOptions() { + Schema schema = Schema.builder().addField(FIELD.withOptions(setOptionsSet1().build())).build(); + assertTrue(schema.getField(FIELD_NAME).getOptions().hasOptions()); + } + + @Test + public void testFieldHasNonOptions() { + Schema schema = Schema.builder().addField(FIELD).build(); + assertFalse(schema.getField(FIELD_NAME).getOptions().hasOptions()); + } + + @Test + public void testFieldHasOption() { + Schema schema = Schema.builder().addField(FIELD.withOptions(setOptionsSet1().build())).build(); + assertTrue(schema.getField(FIELD_NAME).getOptions().hasOption("field_option_byte")); + assertFalse(schema.getField(FIELD_NAME).getOptions().hasOption("foo_bar")); + } + + @Test + public void testFieldOptionNames() { + Schema schema = Schema.builder().addField(FIELD.withOptions(setOptionsSet1().build())).build(); + Set optionNames = schema.getField(FIELD_NAME).getOptions().getOptionNames(); + assertThat( + optionNames, + containsInAnyOrder( + "field_option_boolean", + "field_option_byte", + "field_option_int16", + "field_option_int32", + "field_option_int64", + "field_option_string")); + } + + @Test + public void testFieldBuilderSetOptions() { + Schema schema = + Schema.builder() + .addField(FIELD.toBuilder().setOptions(setOptionsSet1().build()).build()) + .build(); + assertOptionSet1(schema.getField(FIELD_NAME).getOptions()); + } + + @Test + public void testFieldBuilderSetOptionsBuilder() { + Schema schema = + Schema.builder().addField(FIELD.toBuilder().setOptions(setOptionsSet1()).build()).build(); + assertOptionSet1(schema.getField(FIELD_NAME).getOptions()); + } + + @Test + public void testAddOptions() { + Schema.Options options = + setOptionsSet1(Schema.Options.builder()) + .addOptions(setOptionsSet2(Schema.Options.builder()).build()) + .build(); + assertOptionSet1(options); + assertOptionSet2(options); + } +}