From fa16dc8b885caf11324465f241d0bf090d8a4239 Mon Sep 17 00:00:00 2001 From: Scott Wegner Date: Tue, 12 Apr 2016 10:08:37 -0700 Subject: [PATCH] Add display data to ParDo transforms --- .../DataflowPipelineTranslatorTest.java | 94 +++++++++++-------- .../inprocess/ForwardingPTransform.java | 6 ++ .../beam/sdk/transforms/DoFnReflector.java | 6 ++ .../beam/sdk/transforms/DoFnWithContext.java | 14 ++- .../apache/beam/sdk/transforms/Filter.java | 27 ++++++ .../beam/sdk/transforms/GroupByKey.java | 9 ++ .../IntraBundleParallelization.java | 9 ++ .../org/apache/beam/sdk/transforms/ParDo.java | 62 +++++++++--- .../apache/beam/sdk/transforms/Partition.java | 13 +++ .../inprocess/ForwardingPTransformTest.java | 10 ++ .../sdk/transforms/DoFnWithContextTest.java | 11 +++ .../beam/sdk/transforms/FilterTest.java | 20 ++++ .../beam/sdk/transforms/GroupByKeyTest.java | 15 +++ .../IntraBundleParallelizationTest.java | 26 +++++ .../apache/beam/sdk/transforms/ParDoTest.java | 74 ++++++++++++--- .../beam/sdk/transforms/PartitionTest.java | 13 +++ 16 files changed, 341 insertions(+), 68 deletions(-) diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/sdk/runners/DataflowPipelineTranslatorTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/sdk/runners/DataflowPipelineTranslatorTest.java index 1429e5a07b67..2bf6e2d91ed2 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/sdk/runners/DataflowPipelineTranslatorTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/sdk/runners/DataflowPipelineTranslatorTest.java @@ -72,9 +72,9 @@ import com.google.api.services.dataflow.model.WorkerPool; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; -import org.hamcrest.Matcher; import org.hamcrest.Matchers; import org.junit.Assert; import org.junit.Rule; @@ -840,10 +840,12 @@ public void populateDisplayData(DisplayData.Builder builder) { } }; + ParDo.Bound parDo1 = ParDo.of(fn1); + ParDo.Bound parDo2 = ParDo.of(fn2); pipeline .apply(Create.of(1, 2, 3)) - .apply(ParDo.of(fn1)) - .apply(ParDo.of(fn2)); + .apply(parDo1) + .apply(parDo2); Job job = translator.translate( pipeline, pipeline.getRunner(), Collections.emptyList()).getJob(); @@ -855,43 +857,53 @@ public void populateDisplayData(DisplayData.Builder builder) { Map parDo2Properties = steps.get(2).getProperties(); assertThat(parDo1Properties, hasKey("display_data")); - - @SuppressWarnings("unchecked") - Collection> fn1displayData = - (Collection>) parDo1Properties.get("display_data"); - @SuppressWarnings("unchecked") - Collection> fn2displayData = - (Collection>) parDo2Properties.get("display_data"); - - @SuppressWarnings("unchecked") - Matcher>> fn1expectedData = - Matchers.>containsInAnyOrder( - ImmutableMap.builder() - .put("namespace", fn1.getClass().getName()) - .put("key", "foo") - .put("type", "STRING") - .put("value", "bar") - .build(), - ImmutableMap.builder() - .put("namespace", fn1.getClass().getName()) - .put("key", "foo2") - .put("type", "JAVA_CLASS") - .put("value", DataflowPipelineTranslatorTest.class.getName()) - .put("shortValue", DataflowPipelineTranslatorTest.class.getSimpleName()) - .put("label", "Test Class") - .put("linkUrl", "http://www.google.com") - .build()); - - @SuppressWarnings("unchecked") - Matcher>> fn2expectedData = - Matchers.>contains( - ImmutableMap.builder() - .put("namespace", fn2.getClass().getName()) - .put("key", "foo3") - .put("type", "INTEGER") - .put("value", 1234L) - .build()); - assertThat(fn1displayData, fn1expectedData); - assertThat(fn2displayData, fn2expectedData); + Collection> fn1displayData = + (Collection>) parDo1Properties.get("display_data"); + Collection> fn2displayData = + (Collection>) parDo2Properties.get("display_data"); + + ImmutableSet> expectedFn1DisplayData = ImmutableSet.of( + ImmutableMap.builder() + .put("key", "foo") + .put("type", "STRING") + .put("value", "bar") + .put("namespace", fn1.getClass().getName()) + .build(), + ImmutableMap.builder() + .put("key", "fn") + .put("type", "JAVA_CLASS") + .put("value", fn1.getClass().getName()) + .put("shortValue", fn1.getClass().getSimpleName()) + .put("namespace", parDo1.getClass().getName()) + .build(), + ImmutableMap.builder() + .put("key", "foo2") + .put("type", "JAVA_CLASS") + .put("value", DataflowPipelineTranslatorTest.class.getName()) + .put("shortValue", DataflowPipelineTranslatorTest.class.getSimpleName()) + .put("namespace", fn1.getClass().getName()) + .put("label", "Test Class") + .put("linkUrl", "http://www.google.com") + .build() + ); + + ImmutableSet> expectedFn2DisplayData = ImmutableSet.of( + ImmutableMap.builder() + .put("key", "fn") + .put("type", "JAVA_CLASS") + .put("value", fn2.getClass().getName()) + .put("shortValue", fn2.getClass().getSimpleName()) + .put("namespace", parDo2.getClass().getName()) + .build(), + ImmutableMap.builder() + .put("key", "foo3") + .put("type", "INTEGER") + .put("value", 1234L) + .put("namespace", fn2.getClass().getName()) + .build() + ); + + assertEquals(expectedFn1DisplayData, ImmutableSet.copyOf(fn1displayData)); + assertEquals(expectedFn2DisplayData, ImmutableSet.copyOf(fn2displayData)); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/inprocess/ForwardingPTransform.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/inprocess/ForwardingPTransform.java index 7833d424c54b..85aa1c4ca0e2 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/inprocess/ForwardingPTransform.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/inprocess/ForwardingPTransform.java @@ -20,6 +20,7 @@ import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; import org.apache.beam.sdk.values.TypedPValue; @@ -53,4 +54,9 @@ public Coder getDefaultOutputCoder(InputT input, @SuppressWarnings("unuse TypedPValue output) throws CannotProvideCoderException { return delegate().getDefaultOutputCoder(input, output); } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + delegate().populateDisplayData(builder); + } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnReflector.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnReflector.java index 08c4391fd1c0..bbc022026af9 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnReflector.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnReflector.java @@ -22,6 +22,7 @@ import org.apache.beam.sdk.transforms.DoFnWithContext.FinishBundle; import org.apache.beam.sdk.transforms.DoFnWithContext.ProcessElement; import org.apache.beam.sdk.transforms.DoFnWithContext.StartBundle; +import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.util.UserCodeException; @@ -653,6 +654,11 @@ protected TypeDescriptor getOutputTypeDescriptor() { return fn.getOutputTypeDescriptor(); } + @Override + public void populateDisplayData(DisplayData.Builder builder) { + fn.populateDisplayData(builder); + } + private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException { in.defaultReadObject(); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnWithContext.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnWithContext.java index 835730cdd7ca..7143626ba4db 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnWithContext.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnWithContext.java @@ -25,6 +25,8 @@ import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.transforms.Combine.CombineFn; import org.apache.beam.sdk.transforms.DoFn.DelegatingAggregator; +import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.transforms.display.HasDisplayData; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.util.WindowingInternals; @@ -82,7 +84,7 @@ * @param the type of the (main) output elements */ @Experimental -public abstract class DoFnWithContext implements Serializable { +public abstract class DoFnWithContext implements Serializable, HasDisplayData { /** Information accessible to all methods in this {@code DoFnWithContext}. */ public abstract class Context { @@ -414,4 +416,14 @@ public final Aggregator createAggregator( void prepareForProcessing() { aggregatorsAreFinal = true; } + + /** + * {@inheritDoc} + * + *

By default, does not register any display data. Implementors may override this method + * to provide their own display metadata. + */ + @Override + public void populateDisplayData(DisplayData.Builder builder) { + } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Filter.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Filter.java index 547254d65a32..0e5e4a62cb2e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Filter.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Filter.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.transforms; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.values.PCollection; /** @@ -99,9 +100,15 @@ public void processElement(ProcessContext c) { c.output(c.element()); } } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + Filter.populateDisplayData(builder, String.format("x < %s", value)); + } }); } + /** * Returns a {@code PTransform} that takes an input * {@code PCollection} and returns a {@code PCollection} with @@ -131,6 +138,11 @@ public void processElement(ProcessContext c) { c.output(c.element()); } } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + Filter.populateDisplayData(builder, String.format("x > %s", value)); + } }); } @@ -163,6 +175,11 @@ public void processElement(ProcessContext c) { c.output(c.element()); } } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + Filter.populateDisplayData(builder, String.format("x ≤ %s", value)); + } }); } @@ -195,6 +212,11 @@ public void processElement(ProcessContext c) { c.output(c.element()); } } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + Filter.populateDisplayData(builder, String.format("x ≥ %s", value)); + } }); } @@ -232,4 +254,9 @@ public void processElement(ProcessContext c) { protected Coder getDefaultOutputCoder(PCollection input) { return input.getCoder(); } + + private static void populateDisplayData( + DisplayData.Builder builder, String predicateDescription) { + builder.add("predicate", predicateDescription); + } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByKey.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByKey.java index 42c1f78962ff..1b3c4542d6d8 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByKey.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByKey.java @@ -21,6 +21,7 @@ import org.apache.beam.sdk.coders.Coder.NonDeterministicException; import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.windowing.DefaultTrigger; import org.apache.beam.sdk.transforms.windowing.GlobalWindows; import org.apache.beam.sdk.transforms.windowing.InvalidWindows; @@ -273,4 +274,12 @@ static Coder> getOutputValueCoder(Coder> inputCoder) public static KvCoder> getOutputKvCoder(Coder> inputCoder) { return KvCoder.of(getKeyCoder(inputCoder), getOutputValueCoder(inputCoder)); } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + if (fewKeys) { + builder.add("fewKeys", true); + } + } + } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/IntraBundleParallelization.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/IntraBundleParallelization.java index c66aa8d47c62..1b915629ab86 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/IntraBundleParallelization.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/IntraBundleParallelization.java @@ -20,6 +20,7 @@ import org.apache.beam.sdk.options.GcsOptions; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.transforms.Combine.CombineFn; +import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.util.WindowingInternals; @@ -172,6 +173,14 @@ public PCollection apply(PCollection input) { return input.apply( ParDo.of(new MultiThreadedIntraBundleProcessingDoFn<>(doFn, maxParallelism))); } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder + .add("maxParallelism", maxParallelism) + .add("fn", doFn.getClass()) + .include(doFn); + } } /** diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java index b448c266daf3..d266155b470c 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java @@ -23,6 +23,7 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.runners.DirectPipelineRunner; +import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.display.DisplayData.Builder; import org.apache.beam.sdk.transforms.windowing.WindowFn; import org.apache.beam.sdk.util.DirectModeExecutionContext; @@ -556,7 +557,12 @@ public static UnboundMulti withOutputTags( * properties can be set on it first. */ public static Bound of(DoFn fn) { - return new Unbound().of(fn); + return of(fn, fn.getClass()); + } + + private static Bound of( + DoFn fn, Class fnClass) { + return new Unbound().of(fn, fnClass); } private static DoFn @@ -579,7 +585,7 @@ public static Bound of(DoFn */ @Experimental public static Bound of(DoFnWithContext fn) { - return of(adapt(fn)); + return of(adapt(fn), fn.getClass()); } /** @@ -666,9 +672,15 @@ public UnboundMulti withOutputTags(TupleTag mainOutp * still be specified. */ public Bound of(DoFn fn) { - return new Bound<>(name, sideInputs, fn); + return of(fn, fn.getClass()); + } + + private Bound of( + DoFn fn, Class fnClass) { + return new Bound<>(name, sideInputs, fn, fnClass); } + /** * Returns a new {@link ParDo} {@link PTransform} that's like this * transform but which will invoke the given {@link DoFnWithContext} @@ -678,7 +690,7 @@ public Bound of(DoFn fn) { * still be specified. */ public Bound of(DoFnWithContext fn) { - return of(adapt(fn)); + return of(adapt(fn), fn.getClass()); } } @@ -699,13 +711,16 @@ public static class Bound // Inherits name. private final List> sideInputs; private final DoFn fn; + private final Class fnClass; Bound(String name, List> sideInputs, - DoFn fn) { + DoFn fn, + Class fnClass) { super(name); this.sideInputs = sideInputs; this.fn = SerializableUtils.clone(fn); + this.fnClass = fnClass; } /** @@ -716,7 +731,7 @@ public static class Bound *

See the discussion of Naming above for more explanation. */ public Bound named(String name) { - return new Bound<>(name, sideInputs, fn); + return new Bound<>(name, sideInputs, fn, fnClass); } /** @@ -744,7 +759,7 @@ public Bound withSideInputs( ImmutableList.Builder> builder = ImmutableList.builder(); builder.addAll(this.sideInputs); builder.addAll(sideInputs); - return new Bound<>(name, builder.build(), fn); + return new Bound<>(name, builder.build(), fn, fnClass); } /** @@ -758,7 +773,7 @@ public Bound withSideInputs( public BoundMulti withOutputTags(TupleTag mainOutputTag, TupleTagList sideOutputTags) { return new BoundMulti<>( - name, sideInputs, mainOutputTag, sideOutputTags, fn); + name, sideInputs, mainOutputTag, sideOutputTags, fn, fnClass); } @Override @@ -799,7 +814,7 @@ protected String getKindString() { */ @Override public void populateDisplayData(Builder builder) { - builder.include(fn); + ParDo.populateDisplayData(builder, fn, fnClass); } public DoFn getFn() { @@ -891,8 +906,12 @@ public UnboundMulti withSideInputs( * more properties can still be specified. */ public BoundMulti of(DoFn fn) { + return of(fn, fn.getClass()); + } + + public BoundMulti of(DoFn fn, Class fnClass) { return new BoundMulti<>( - name, sideInputs, mainOutputTag, sideOutputTags, fn); + name, sideInputs, mainOutputTag, sideOutputTags, fn, fnClass); } /** @@ -904,7 +923,7 @@ public BoundMulti of(DoFn fn) { * more properties can still be specified. */ public BoundMulti of(DoFnWithContext fn) { - return of(adapt(fn)); + return of(adapt(fn), fn.getClass()); } } @@ -926,17 +945,20 @@ public static class BoundMulti private final TupleTag mainOutputTag; private final TupleTagList sideOutputTags; private final DoFn fn; + private final Class fnClass; BoundMulti(String name, List> sideInputs, TupleTag mainOutputTag, TupleTagList sideOutputTags, - DoFn fn) { + DoFn fn, + Class fnClass) { super(name); this.sideInputs = sideInputs; this.mainOutputTag = mainOutputTag; this.sideOutputTags = sideOutputTags; this.fn = SerializableUtils.clone(fn); + this.fnClass = fnClass; } /** @@ -948,7 +970,7 @@ public static class BoundMulti */ public BoundMulti named(String name) { return new BoundMulti<>( - name, sideInputs, mainOutputTag, sideOutputTags, fn); + name, sideInputs, mainOutputTag, sideOutputTags, fn, fnClass); } /** @@ -979,7 +1001,7 @@ public BoundMulti withSideInputs( builder.addAll(sideInputs); return new BoundMulti<>( name, builder.build(), - mainOutputTag, sideOutputTags, fn); + mainOutputTag, sideOutputTags, fn, fnClass); } @@ -1027,6 +1049,11 @@ protected String getKindString() { } } + @Override + public void populateDisplayData(Builder builder) { + ParDo.populateDisplayData(builder, fn, fnClass); + } + public DoFn getFn() { return fn; } @@ -1233,6 +1260,13 @@ private static SideInputReader makeSideInputReader( return DirectSideInputReader.of(sideInputValues); } + private static void populateDisplayData( + DisplayData.Builder builder, DoFn fn, Class fnClass) { + builder + .include(fn, fnClass) + .add("fn", fnClass); + } + /** * A {@code DoFnRunner.OutputManager} that provides facilities for checking output values for * illegal mutations. diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Partition.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Partition.java index 47d49f7c38e4..5366fd0c6f83 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Partition.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Partition.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.transforms; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionList; import org.apache.beam.sdk.values.PCollectionTuple; @@ -121,6 +122,11 @@ public PCollectionList apply(PCollection in) { return pcs; } + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.include(partitionDoFn); + } + private final transient PartitionDoFn partitionDoFn; private Partition(PartitionDoFn partitionDoFn) { @@ -170,5 +176,12 @@ public void processElement(ProcessContext c) { partition + " not in [0.." + numPartitions + ")"); } } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder + .add("numPartitions", numPartitions) + .add("partitionFn", partitionFn.getClass()); + } } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/inprocess/ForwardingPTransformTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/inprocess/ForwardingPTransformTest.java index ca3753c9c3dc..366dfc5597f2 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/inprocess/ForwardingPTransformTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/inprocess/ForwardingPTransformTest.java @@ -25,6 +25,7 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.values.PCollection; import org.junit.Before; @@ -99,4 +100,13 @@ public void getDefaultOutputCoderDelegates() throws Exception { when(delegate.getDefaultOutputCoder(input, output)).thenReturn(outputCoder); assertThat(forwarding.getDefaultOutputCoder(input, output), equalTo(outputCoder)); } + + @Test + public void populateDisplayDataDelegates() { + DisplayData.Builder builder = mock(DisplayData.Builder.class); + doThrow(RuntimeException.class).when(delegate).populateDisplayData(builder); + + thrown.expect(RuntimeException.class); + forwarding.populateDisplayData(builder); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnWithContextTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnWithContextTest.java index 40c80b7780d5..391081adfd9d 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnWithContextTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/DoFnWithContextTest.java @@ -18,8 +18,10 @@ package org.apache.beam.sdk.transforms; import static org.hamcrest.CoreMatchers.isA; +import static org.hamcrest.Matchers.empty; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertThat; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; @@ -29,6 +31,7 @@ import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Combine.CombineFn; import org.apache.beam.sdk.transforms.Max.MaxIntegerFn; +import org.apache.beam.sdk.transforms.display.DisplayData; import org.junit.Rule; import org.junit.Test; @@ -157,6 +160,14 @@ public void testDoFnWithContextUsingAggregators() { verify(agg).addValue(1L); } + @Test + public void testDefaultPopulateDisplayDataImplementation() { + DoFnWithContext fn = new DoFnWithContext() { + }; + DisplayData displayData = DisplayData.from(fn); + assertThat(displayData.items(), empty()); + } + @Test public void testCreateAggregatorInStartBundleThrows() { TestPipeline p = createTestPipeline(new DoFnWithContext() { diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FilterTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FilterTest.java index f15f48e67c20..f58ba179a615 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FilterTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FilterTest.java @@ -17,9 +17,14 @@ */ package org.apache.beam.sdk.transforms; +import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; + +import static org.hamcrest.MatcherAssert.assertThat; + import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.RunnableOnService; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.values.PCollection; import org.junit.Test; @@ -158,4 +163,19 @@ public void testFilterGreaterThan() { PAssert.that(output).containsInAnyOrder(5, 6, 7); p.run(); } + + @Test + public void testDisplayData() { + ParDo.Bound lessThan = Filter.lessThan(123); + assertThat(DisplayData.from(lessThan), hasDisplayItem("predicate", "x < 123")); + + ParDo.Bound lessThanOrEqual = Filter.lessThanEq(234); + assertThat(DisplayData.from(lessThanOrEqual), hasDisplayItem("predicate", "x ≤ 234")); + + ParDo.Bound greaterThan = Filter.greaterThan(345); + assertThat(DisplayData.from(greaterThan), hasDisplayItem("predicate", "x > 345")); + + ParDo.Bound greaterThanOrEqual = Filter.greaterThanEq(456); + assertThat(DisplayData.from(greaterThanOrEqual), hasDisplayItem("predicate", "x ≥ 456")); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java index 1a7b0b7c0190..b84845ab162c 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java @@ -18,7 +18,9 @@ package org.apache.beam.sdk.transforms; import static org.apache.beam.sdk.TestUtils.KvMatcher.isKv; +import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder; import static org.hamcrest.core.Is.is; @@ -35,6 +37,7 @@ import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.RunnableOnService; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.windowing.FixedWindows; import org.apache.beam.sdk.transforms.windowing.InvalidWindows; import org.apache.beam.sdk.transforms.windowing.OutputTimeFns; @@ -375,4 +378,16 @@ public void processElement(ProcessContext c) throws Exception { public void testGroupByKeyGetName() { Assert.assertEquals("GroupByKey", GroupByKey.create().getName()); } + + @Test + public void testDisplayData() { + GroupByKey groupByKey = GroupByKey.create(); + GroupByKey groupByFewKeys = GroupByKey.create(true); + + DisplayData gbkDisplayData = DisplayData.from(groupByKey); + DisplayData fewKeysDisplayData = DisplayData.from(groupByFewKeys); + + assertThat(gbkDisplayData.items(), empty()); + assertThat(fewKeysDisplayData, hasDisplayItem("fewKeys", true)); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/IntraBundleParallelizationTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/IntraBundleParallelizationTest.java index dd0191923c49..80f6188c1cb6 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/IntraBundleParallelizationTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/IntraBundleParallelizationTest.java @@ -18,6 +18,8 @@ package org.apache.beam.sdk.transforms; import static org.apache.beam.sdk.testing.SystemNanoTimeSleeper.sleepMillis; +import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; +import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.includes; import static org.hamcrest.Matchers.both; import static org.hamcrest.Matchers.containsString; @@ -31,6 +33,7 @@ import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.display.DisplayData; import org.junit.Before; import org.junit.Test; @@ -215,6 +218,29 @@ public void testIntraBundleParallelizationGetName() { IntraBundleParallelization.of(new DelayFn()).withMaxParallelism(1).getName()); } + @Test + public void testDisplayData() { + DoFn fn = new DoFn() { + @Override + public void processElement(ProcessContext c) throws Exception { + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add("foo", "bar"); + } + }; + + PTransform transform = IntraBundleParallelization + .withMaxParallelism(1234) + .of(fn); + + DisplayData displayData = DisplayData.from(transform); + assertThat(displayData, includes(fn)); + assertThat(displayData, hasDisplayItem("fn", fn.getClass())); + assertThat(displayData, hasDisplayItem("maxParallelism", 1234)); + } + /** * Runs the provided doFn inside of an {@link IntraBundleParallelization} transform. * diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java index 5724dd605911..44154e62e4f0 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java @@ -19,10 +19,13 @@ import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasKey; +import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasType; +import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.includes; import static org.apache.beam.sdk.util.SerializableUtils.serializeToByteArray; import static org.apache.beam.sdk.util.StringUtils.byteArrayToJsonString; import static org.apache.beam.sdk.util.StringUtils.jsonStringToByteArray; +import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.isA; import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder; import static org.hamcrest.collection.IsIterableContainingInOrder.contains; @@ -46,6 +49,7 @@ import org.apache.beam.sdk.transforms.ParDo.Bound; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.display.DisplayData.Builder; +import org.apache.beam.sdk.transforms.display.DisplayDataMatchers; import org.apache.beam.sdk.transforms.windowing.FixedWindows; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.util.IllegalMutationException; @@ -1525,20 +1529,66 @@ public void testMutatingInputCoderDoFnError() throws Exception { } @Test - public void testIncludesDoFnDisplayData() { - Bound parDo = - ParDo.of( - new DoFn() { - @Override - public void processElement(ProcessContext c) {} + public void testDoFnDisplayData() { + DoFn fn = new DoFn() { + @Override + public void processElement(ProcessContext c) { + } - @Override - public void populateDisplayData(Builder builder) { - builder.add("foo", "bar"); - } - }); + @Override + public void populateDisplayData(Builder builder) { + builder.add("doFnMetadata", "bar"); + } + }; + + Bound parDo = ParDo.of(fn); + + DisplayData displayData = DisplayData.from(parDo); + assertThat(displayData, hasDisplayItem(allOf( + hasKey("fn"), + hasType(DisplayData.Type.JAVA_CLASS), + DisplayDataMatchers.hasValue(fn.getClass().getName())))); + + assertThat(displayData, includes(fn)); + } + + @Test + public void testDoFnWithContextDisplayData() { + DoFnWithContext fn = new DoFnWithContext() { + @ProcessElement + public void proccessElement(ProcessContext c) {} + + @Override + public void populateDisplayData(Builder builder) { + builder.add("fnMetadata", "foobar"); + } + }; + + Bound parDo = ParDo.of(fn); + + DisplayData displayData = DisplayData.from(parDo); + assertThat(displayData, includes(fn)); + assertThat(displayData, hasDisplayItem("fn", fn.getClass())); + } + + @Test + public void testWithOutputTagsDisplayData() { + DoFnWithContext fn = new DoFnWithContext() { + @ProcessElement + public void proccessElement(ProcessContext c) {} + + @Override + public void populateDisplayData(Builder builder) { + builder.add("fnMetadata", "foobar"); + } + }; + + ParDo.BoundMulti parDo = ParDo + .withOutputTags(new TupleTag(), TupleTagList.empty()) + .of(fn); DisplayData displayData = DisplayData.from(parDo); - assertThat(displayData, hasDisplayItem(hasKey("foo"))); + assertThat(displayData, includes(fn)); + assertThat(displayData, hasDisplayItem("fn", fn.getClass())); } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/PartitionTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/PartitionTest.java index dba6c1691cfc..608da0f7d591 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/PartitionTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/PartitionTest.java @@ -17,6 +17,9 @@ */ package org.apache.beam.sdk.transforms; +import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; + +import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -25,6 +28,7 @@ import org.apache.beam.sdk.testing.RunnableOnService; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Partition.PartitionFn; +import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionList; @@ -138,4 +142,13 @@ public void testDroppedPartition() { public void testPartitionGetName() { assertEquals("Partition", Partition.of(3, new ModFn()).getName()); } + + @Test + public void testDisplayData() { + Partition partition = Partition.of(123, new IdentityFn()); + DisplayData displayData = DisplayData.from(partition); + + assertThat(displayData, hasDisplayItem("numPartitions", 123)); + assertThat(displayData, hasDisplayItem("partitionFn", IdentityFn.class)); + } }