diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjects.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjects.java index ec51351a8783..ef47e1fdd15f 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjects.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/CloudObjects.java @@ -29,6 +29,7 @@ import org.apache.beam.sdk.coders.ByteArrayCoder; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CustomCoder; +import org.apache.beam.sdk.coders.DoubleCoder; import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.LengthPrefixCoder; @@ -49,6 +50,7 @@ private CloudObjects() {} ByteArrayCoder.class, KvCoder.class, VarLongCoder.class, + DoubleCoder.class, IntervalWindowCoder.class, IterableCoder.class, Timer.Coder.class, diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/util/CloudObjectsTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/util/CloudObjectsTest.java index 215567e10797..26b721cdd060 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/util/CloudObjectsTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/util/CloudObjectsTest.java @@ -17,11 +17,12 @@ */ package org.apache.beam.runners.dataflow.util; +import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.emptyIterable; import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.instanceOf; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import java.io.IOException; @@ -32,7 +33,6 @@ import java.util.HashSet; import java.util.List; import java.util.Set; -import org.apache.beam.runners.core.construction.ModelCoderRegistrar; import org.apache.beam.runners.core.construction.SdkComponents; import org.apache.beam.sdk.coders.AvroCoder; import org.apache.beam.sdk.coders.ByteArrayCoder; @@ -50,7 +50,9 @@ import org.apache.beam.sdk.coders.SetCoder; import org.apache.beam.sdk.coders.StructuredCoder; import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.schemas.LogicalTypes; import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.SchemaCoder; import org.apache.beam.sdk.transforms.join.CoGbkResult.CoGbkResultCoder; import org.apache.beam.sdk.transforms.join.CoGbkResultSchema; @@ -62,7 +64,9 @@ import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList.Builder; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; import org.junit.Test; +import org.junit.experimental.runners.Enclosed; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.junit.runners.Parameterized; @@ -70,7 +74,22 @@ import org.junit.runners.Parameterized.Parameters; /** Tests for {@link CloudObjects}. */ +@RunWith(Enclosed.class) public class CloudObjectsTest { + private static final Schema TEST_SCHEMA = + Schema.builder() + .addBooleanField("bool") + .addByteField("int8") + .addInt16Field("int16") + .addInt32Field("int32") + .addInt64Field("int64") + .addFloatField("float") + .addDoubleField("double") + .addStringField("string") + .addArrayField("list_int32", FieldType.INT32) + .addLogicalTypeField("fixed_bytes", LogicalTypes.FixedBytes.of(4)) + .build(); + /** Tests that all of the Default Coders are tested. */ @RunWith(JUnit4.class) public static class DefaultsPresentTest { @@ -143,7 +162,8 @@ public static Iterable> data() { CoGbkResultSchema.of( ImmutableList.of(new TupleTag(), new TupleTag())), UnionCoder.of(ImmutableList.of(VarLongCoder.of(), ByteArrayCoder.of())))) - .add(SchemaCoder.of(Schema.builder().build())); + .add(SchemaCoder.of(Schema.builder().build())) + .add(SchemaCoder.of(TEST_SCHEMA)); for (Class atomicCoder : DefaultCoderCloudObjectTranslatorRegistrar.KNOWN_ATOMIC_CODERS) { dataBuilder.add(InstanceBuilder.ofType(atomicCoder).fromFactoryMethod("of").build()); @@ -177,21 +197,33 @@ public void toAndFromCloudObjectWithSdkComponents() throws Exception { private static void checkPipelineProtoCoderIds( Coder coder, CloudObject cloudObject, SdkComponents sdkComponents) throws Exception { - if (ModelCoderRegistrar.isKnownCoder(coder)) { + if (CloudObjects.DATAFLOW_KNOWN_CODERS.contains(coder.getClass())) { assertFalse(cloudObject.containsKey(PropertyNames.PIPELINE_PROTO_CODER_ID)); } else { assertTrue(cloudObject.containsKey(PropertyNames.PIPELINE_PROTO_CODER_ID)); assertEquals( sdkComponents.registerCoder(coder), - cloudObject.get(PropertyNames.PIPELINE_PROTO_CODER_ID)); + ((CloudObject) cloudObject.get(PropertyNames.PIPELINE_PROTO_CODER_ID)) + .get(PropertyNames.VALUE)); + } + List> expectedComponents; + if (coder instanceof StructuredCoder) { + expectedComponents = ((StructuredCoder) coder).getComponents(); + } else { + expectedComponents = coder.getCoderArguments(); } - List> coderArguments = coder.getCoderArguments(); Object cloudComponentsObject = cloudObject.get(PropertyNames.COMPONENT_ENCODINGS); - assertTrue(cloudComponentsObject instanceof List); - List cloudComponents = (List) cloudComponentsObject; - assertEquals(coderArguments.size(), cloudComponents.size()); - for (int i = 0; i < coderArguments.size(); i++) { - checkPipelineProtoCoderIds(coderArguments.get(i), cloudComponents.get(i), sdkComponents); + List cloudComponents; + if (cloudComponentsObject == null) { + cloudComponents = Lists.newArrayList(); + } else { + assertThat(cloudComponentsObject, instanceOf(List.class)); + cloudComponents = (List) cloudComponentsObject; + } + assertEquals(expectedComponents.size(), cloudComponents.size()); + for (int i = 0; i < expectedComponents.size(); i++) { + checkPipelineProtoCoderIds( + expectedComponents.get(i), cloudComponents.get(i), sdkComponents); } } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java index f6cfe6ac34ae..79faa915b824 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java @@ -24,6 +24,7 @@ import java.io.OutputStream; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.UUID; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -247,4 +248,21 @@ public String toString() { String string = "Schema: " + schema + " UUID: " + id + " delegateCoder: " + getDelegateCoder(); return string; } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + RowCoder rowCoder = (RowCoder) o; + return schema.equals(rowCoder.schema); + } + + @Override + public int hashCode() { + return Objects.hash(schema); + } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaCoder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaCoder.java index 0199534c3aba..9e06b4420fc0 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaCoder.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaCoder.java @@ -20,12 +20,12 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.util.Objects; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; import org.apache.beam.sdk.coders.CustomCoder; import org.apache.beam.sdk.coders.RowCoder; import org.apache.beam.sdk.transforms.SerializableFunction; -import org.apache.beam.sdk.transforms.SerializableFunctions; import org.apache.beam.sdk.values.Row; /** {@link SchemaCoder} is used as the coder for types that have schemas registered. */ @@ -57,8 +57,7 @@ public static SchemaCoder of( /** Returns a {@link SchemaCoder} for {@link Row} classes. */ public static SchemaCoder of(Schema schema) { - return new SchemaCoder<>( - schema, SerializableFunctions.identity(), SerializableFunctions.identity()); + return new SchemaCoder<>(schema, identity(), identity()); } /** Returns the schema associated with this type. */ @@ -100,4 +99,47 @@ public boolean consistentWithEquals() { public String toString() { return "SchemaCoder: " + rowCoder.toString(); } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SchemaCoder that = (SchemaCoder) o; + return rowCoder.equals(that.rowCoder) + && toRowFunction.equals(that.toRowFunction) + && fromRowFunction.equals(that.fromRowFunction); + } + + @Override + public int hashCode() { + return Objects.hash(rowCoder, toRowFunction, fromRowFunction); + } + + private static RowIdentity identity() { + return new RowIdentity(); + } + + private static class RowIdentity implements SerializableFunction { + @Override + public Row apply(Row input) { + return input; + } + + @Override + public int hashCode() { + return Objects.hash(getClass()); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + return o != null && getClass() == o.getClass(); + } + } }