@@ -29,6 +29,7 @@
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CustomCoder;
+import org.apache.beam.sdk.coders.DoubleCoder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.LengthPrefixCoder;
@@ -49,6 +50,7 @@ private CloudObjects() {}
ByteArrayCoder.class,
KvCoder.class,
VarLongCoder.class,
+DoubleCoder.class,
IntervalWindowCoder.class,
IterableCoder.class,
Timer.Coder.class,
@@ -17,11 +17,12 @@
*/
package org.apache.beam.runners.dataflow.util;

+import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.emptyIterable;
import static org.hamcrest.Matchers.hasItem;
+import static org.hamcrest.Matchers.instanceOf;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;

import java.io.IOException;
@@ -32,7 +33,6 @@
import java.util.HashSet;
import java.util.List;
import java.util.Set;
-import org.apache.beam.runners.core.construction.ModelCoderRegistrar;
import org.apache.beam.runners.core.construction.SdkComponents;
import org.apache.beam.sdk.coders.AvroCoder;
import org.apache.beam.sdk.coders.ByteArrayCoder;
@@ -50,7 +50,9 @@
import org.apache.beam.sdk.coders.SetCoder;
import org.apache.beam.sdk.coders.StructuredCoder;
import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.schemas.LogicalTypes;
import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.transforms.join.CoGbkResult.CoGbkResultCoder;
import org.apache.beam.sdk.transforms.join.CoGbkResultSchema;
@@ -62,15 +64,32 @@
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList.Builder;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.junit.Test;
import org.junit.experimental.runners.Enclosed;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameter;
import org.junit.runners.Parameterized.Parameters;

/** Tests for {@link CloudObjects}. */
@RunWith(Enclosed.class)
public class CloudObjectsTest {
+private static final Schema TEST_SCHEMA =
+Schema.builder()
+.addBooleanField("bool")
+.addByteField("int8")
+.addInt16Field("int16")
+.addInt32Field("int32")
+.addInt64Field("int64")
+.addFloatField("float")
+.addDoubleField("double")
+.addStringField("string")
+.addArrayField("list_int32", FieldType.INT32)
+.addLogicalTypeField("fixed_bytes", LogicalTypes.FixedBytes.of(4))
+.build();
+
/** Tests that all of the Default Coders are tested. */
@RunWith(JUnit4.class)
public static class DefaultsPresentTest {
@@ -143,7 +162,8 @@ public static Iterable<Coder<?>> data() {
CoGbkResultSchema.of(
ImmutableList.of(new TupleTag<Long>(), new TupleTag<byte[]>())),
UnionCoder.of(ImmutableList.of(VarLongCoder.of(), ByteArrayCoder.of()))))
-.add(SchemaCoder.of(Schema.builder().build()));
+.add(SchemaCoder.of(Schema.builder().build()))
+.add(SchemaCoder.of(TEST_SCHEMA));
for (Class<? extends Coder> atomicCoder :
DefaultCoderCloudObjectTranslatorRegistrar.KNOWN_ATOMIC_CODERS) {
dataBuilder.add(InstanceBuilder.ofType(atomicCoder).fromFactoryMethod("of").build());
@@ -177,21 +197,33 @@ public void toAndFromCloudObjectWithSdkComponents() throws Exception {

private static void checkPipelineProtoCoderIds(
Coder<?> coder, CloudObject cloudObject, SdkComponents sdkComponents) throws Exception {
-if (ModelCoderRegistrar.isKnownCoder(coder)) {
+if (CloudObjects.DATAFLOW_KNOWN_CODERS.contains(coder.getClass())) {
assertFalse(cloudObject.containsKey(PropertyNames.PIPELINE_PROTO_CODER_ID));
} else {
assertTrue(cloudObject.containsKey(PropertyNames.PIPELINE_PROTO_CODER_ID));
assertEquals(
sdkComponents.registerCoder(coder),
-cloudObject.get(PropertyNames.PIPELINE_PROTO_CODER_ID));
+((CloudObject) cloudObject.get(PropertyNames.PIPELINE_PROTO_CODER_ID))
+.get(PropertyNames.VALUE));
}
-List<? extends Coder<?>> coderArguments = coder.getCoderArguments();
+List<? extends Coder<?>> expectedComponents;
+if (coder instanceof StructuredCoder) {
+expectedComponents = ((StructuredCoder) coder).getComponents();
+} else {
+expectedComponents = coder.getCoderArguments();
+}
Object cloudComponentsObject = cloudObject.get(PropertyNames.COMPONENT_ENCODINGS);
-assertTrue(cloudComponentsObject instanceof List);
-List<CloudObject> cloudComponents = (List<CloudObject>) cloudComponentsObject;
-assertEquals(coderArguments.size(), cloudComponents.size());
-for (int i = 0; i < coderArguments.size(); i++) {
-checkPipelineProtoCoderIds(coderArguments.get(i), cloudComponents.get(i), sdkComponents);
+List<CloudObject> cloudComponents;
+if (cloudComponentsObject == null) {
+cloudComponents = Lists.newArrayList();
+} else {
+assertThat(cloudComponentsObject, instanceOf(List.class));
+cloudComponents = (List<CloudObject>) cloudComponentsObject;
+}
+assertEquals(expectedComponents.size(), cloudComponents.size());
+for (int i = 0; i < expectedComponents.size(); i++) {
+checkPipelineProtoCoderIds(
+expectedComponents.get(i), cloudComponents.get(i), sdkComponents);
}
}
}
@@ -24,6 +24,7 @@
import java.io.OutputStream;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
import java.util.UUID;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
@@ -247,4 +248,21 @@ public String toString() {
String string = "Schema: " + schema + " UUID: " + id + " delegateCoder: " + getDelegateCoder();
return string;
}

+@Override
+public boolean equals(Object o) {
+if (this == o) {
+return true;
+}
+if (o == null || getClass() != o.getClass()) {
+return false;
+}
+RowCoder rowCoder = (RowCoder) o;
+return schema.equals(rowCoder.schema);
+}
+
+@Override
+public int hashCode() {
+return Objects.hash(schema);
+}
}
@@ -20,12 +20,12 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
+import java.util.Objects;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.annotations.Experimental.Kind;
import org.apache.beam.sdk.coders.CustomCoder;
import org.apache.beam.sdk.coders.RowCoder;
import org.apache.beam.sdk.transforms.SerializableFunction;
-import org.apache.beam.sdk.transforms.SerializableFunctions;
import org.apache.beam.sdk.values.Row;

/** {@link SchemaCoder} is used as the coder for types that have schemas registered. */
@@ -57,8 +57,7 @@ public static <T> SchemaCoder<T> of(

/** Returns a {@link SchemaCoder} for {@link Row} classes. */
public static SchemaCoder<Row> of(Schema schema) {
-return new SchemaCoder<>(
-schema, SerializableFunctions.identity(), SerializableFunctions.identity());
+return new SchemaCoder<>(schema, identity(), identity());
}

/** Returns the schema associated with this type. */
@@ -100,4 +99,47 @@ public boolean consistentWithEquals() {
public String toString() {
return "SchemaCoder: " + rowCoder.toString();
}

+@Override
+public boolean equals(Object o) {
+if (this == o) {
+return true;
+}
+if (o == null || getClass() != o.getClass()) {
+return false;
+}
+SchemaCoder<?> that = (SchemaCoder<?>) o;
+return rowCoder.equals(that.rowCoder)
+&& toRowFunction.equals(that.toRowFunction)
+&& fromRowFunction.equals(that.fromRowFunction);

Contributor: Doesn't this just revert to object equality comparison on the to/from functions?

Member Author: Yeah - I discussed this offline a bit with @kennknowles and he convinced me that it was better to have an equals function that might have some false negatives (if the toRowFunction and fromRowFunction don't have a good equals), rather than one that could have false positives (like if we rely on just checking the schema and typeDescriptor, and assume that the toRow/fromRow are the same).

I managed to make the CloudObjectsTest work by adding RowIdentity with an equals() function here.

Member: The way I would phrase this is: let the functions own their equals. If they say they are equal, they are. If they say they aren't, they aren't. So this equals() is relative to that.

Contributor: Sounds good in theory. In practice these functions are usually lambdas, so we might have trouble making this work.
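A minimal sketch of the lambda concern (an illustration, not code from this PR): lambdas inherit Object.equals, so two structurally identical functions are never equal, and a SchemaCoder built from freshly written lambdas would always compare unequal.

import org.apache.beam.sdk.transforms.SerializableFunction;

public class LambdaEqualityDemo {
  public static void main(String[] args) {
    SerializableFunction<Integer, Integer> f = x -> x + 1;
    SerializableFunction<Integer, Integer> g = x -> x + 1;
    // Identical bodies, but equals() is inherited from Object, so in
    // practice this prints false: the two lambdas are distinct objects.
    System.out.println(f.equals(g));
  }
}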

Member Author: That's true. I was thinking it's not such a big deal to get false negatives when lambdas are used, since I really just want the equality check to use in tests.

What do you think about updating the various schema providers to create Function sub-classes (with equals implemented) instead of using lambdas?
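Such a sub-class might look like the following (a hypothetical sketch; FieldExtractor is not a real Beam class): equality is defined by the state the function was constructed from, so two independently built instances can compare equal.

import java.util.Objects;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.Row;

// Hypothetical example: extracts a single named field from a Row. Two
// extractors for the same field name compare equal, unlike two
// behaviorally identical lambdas.
class FieldExtractor implements SerializableFunction<Row, Object> {
  private final String fieldName;

  FieldExtractor(String fieldName) {
    this.fieldName = fieldName;
  }

  @Override
  public Object apply(Row row) {
    return row.getValue(fieldName);
  }

  @Override
  public boolean equals(Object o) {
    return o instanceof FieldExtractor
        && fieldName.equals(((FieldExtractor) o).fieldName);
  }

  @Override
  public int hashCode() {
    return Objects.hash(getClass(), fieldName);
  }
}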

Member Author: Another alternative could be to add something like assertEquivalentSchemaCoder that just checks schema and type, rather than continuing down this rabbit hole.
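That alternative could be as small as this (a sketch of the suggestion only; no such helper exists in the codebase at this point):

// Hypothetical test helper: treat two SchemaCoders as equivalent if their
// schemas match, deliberately ignoring the to/from Row functions. A type
// check would additionally need a TypeDescriptor plumbed through, per the
// follow-up discussed below.
static void assertEquivalentSchemaCoder(SchemaCoder<?> expected, SchemaCoder<?> actual) {
  assertEquals(expected.getSchema(), actual.getSchema());
}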

Member Author: Could we go ahead and merge this as is? I could follow up with more changes to the SchemaCoder equals (plumbing through a type descriptor and using that for comparison, as well as possibly changing the toRow/fromRow functions created by the existing SchemaProviders to make them comparable).

Member Author: I have a PR up now (#9493) that adds equals and hashCode to the fromRow and toRow functions created by all the GetterBasedSchemaProvider sub-classes.

Contributor: BTW this is not just for tests. The Flink runner appears to rely on coder equality (even though you can argue it shouldn't).

+}
+
+@Override
+public int hashCode() {
+return Objects.hash(rowCoder, toRowFunction, fromRowFunction);
+}
+
+private static RowIdentity identity() {
+return new RowIdentity();
+}
+
+private static class RowIdentity implements SerializableFunction<Row, Row> {
+@Override
+public Row apply(Row input) {
+return input;
+}
+
+@Override
+public int hashCode() {
+return Objects.hash(getClass());
+}
+
+@Override
+public boolean equals(Object o) {
+if (this == o) {
+return true;
+}
+return o != null && getClass() == o.getClass();
+}
+}
}
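Taken together, RowCoder.equals compares schemas and RowIdentity compares by class, so two identity-based coders over the same schema now compare equal. A quick sketch of the behavior this diff enables (using only SchemaCoder.of and RowIdentity as defined above):

Schema schema = Schema.builder().addStringField("name").build();
// Each call builds a fresh coder, but the underlying RowCoders share a
// schema and the RowIdentity functions compare equal by class, so this
// assertion holds. Before this change the identity functions had no
// value-based equals, so the same check could fail.
assertEquals(SchemaCoder.of(schema), SchemaCoder.of(schema));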