diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java index b581eecf4140..8d8da216afac 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java @@ -415,6 +415,7 @@ private static class PairWithRestrictionFn @Setup public void setup() { invoker = DoFnInvokers.invokerFor(fn); + invoker.invokeSetup(); } @ProcessElement @@ -422,6 +423,12 @@ public void processElement(ProcessContext context) { context.output( KV.of(context.element(), invoker.invokeGetInitialRestriction(context.element()))); } + + @Teardown + public void tearDown() { + invoker.invokeTeardown(); + invoker = null; + } } /** Splits the restriction using the given {@link SplitRestriction} method. */ @@ -439,6 +446,7 @@ private static class SplitRestrictionFn @Setup public void setup() { invoker = DoFnInvokers.invokerFor(splittableFn); + invoker.invokeSetup(); } @ProcessElement @@ -459,5 +467,11 @@ public void outputWithTimestamp(RestrictionT part, Instant timestamp) { } }); } + + @Teardown + public void tearDown() { + invoker.invokeTeardown(); + invoker = null; + } } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java index b7f0c10d046a..fe33b1ab5b8c 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java @@ -513,18 +513,19 @@ private enum State { private State state = State.BEFORE_SETUP; - @ProcessElement - public void processElement(ProcessContext c, OffsetRangeTracker tracker) { - assertEquals(State.INSIDE_BUNDLE, state); - assertTrue(tracker.tryClaim(0L)); - c.output(c.element()); - } - @GetInitialRestriction public OffsetRange getInitialRestriction(String value) { + assertEquals(State.OUTSIDE_BUNDLE, state); return new OffsetRange(0, 1); } + @SplitRestriction + public void splitRestriction( + String value, OffsetRange range, OutputReceiver receiver) { + assertEquals(State.OUTSIDE_BUNDLE, state); + receiver.output(range); + } + @Setup public void setUp() { assertEquals(State.BEFORE_SETUP, state); @@ -537,6 +538,13 @@ public void startBundle() { state = State.INSIDE_BUNDLE; } + @ProcessElement + public void processElement(ProcessContext c, OffsetRangeTracker tracker) { + assertEquals(State.INSIDE_BUNDLE, state); + assertTrue(tracker.tryClaim(0L)); + c.output(c.element()); + } + @FinishBundle public void finishBundle() { assertEquals(State.INSIDE_BUNDLE, state); @@ -553,12 +561,9 @@ public void tearDown() { @Test @Category({ValidatesRunner.class, UsesSplittableParDo.class}) public void testLifecycleMethods() throws Exception { - PCollection res = p.apply(Create.of("a", "b", "c")).apply(ParDo.of(new SDFWithLifecycle())); - PAssert.that(res).containsInAnyOrder("a", "b", "c"); - p.run(); }