From 3cf8cf31febc5ab36088d6ab2d7992f4bf35ad9f Mon Sep 17 00:00:00 2001 From: Luke Cwik Date: Fri, 7 Jan 2022 12:43:10 -0800 Subject: [PATCH] [BEAM-13614] Add OnWindowExpiration support to the Java SDK harness and proto translation. This implementation adds a timer family spec in the event time domain and adds the field to the ParDoPayload mentioning which timer family spec represents the on window expiration callback. --- .../src/main/proto/beam_runner_api.proto | 12 +- .../core/construction/ParDoTranslation.java | 86 ++++++-- .../core/construction/SplittableParDo.java | 11 +- .../construction/ParDoTranslationTest.java | 9 + .../dataflow/PrimitiveParDoSingleFactory.java | 51 +++-- .../control/RemoteExecutionTest.java | 32 ++- .../transforms/reflect/DoFnSignatures.java | 4 +- .../reflect/DoFnSignaturesTest.java | 29 +++ .../beam/fn/harness/FnApiDoFnRunner.java | 202 +++++++++++++++++- 9 files changed, 385 insertions(+), 51 deletions(-) diff --git a/model/pipeline/src/main/proto/beam_runner_api.proto b/model/pipeline/src/main/proto/beam_runner_api.proto index aaeb9dabfe4e..4669063a81b7 100644 --- a/model/pipeline/src/main/proto/beam_runner_api.proto +++ b/model/pipeline/src/main/proto/beam_runner_api.proto @@ -524,6 +524,12 @@ message ParDoPayload { // be placed in the pipeline requirements. bool requires_stable_input = 11; + // If populated, the name of the timer family spec which should be notified + // on each window expiry. + // If this is set, the corresponding standard requirement should also + // be placed in the pipeline requirements. + string on_window_expiration_timer_family_spec = 12; + reserved 6; } @@ -1601,7 +1607,7 @@ message StandardRunnerProtocols { // to be added in a forwards-compatible way). message StandardRequirements { enum Enum { - // This requirement indicates the state_spec and time_spec fields of ParDo + // This requirement indicates the state_specs and timer_family_specs fields of ParDo // transform payloads must be inspected. REQUIRES_STATEFUL_PROCESSING = 0 [(beam_urn) = "beam:requirement:pardo:stateful:v1"]; @@ -1620,6 +1626,10 @@ message StandardRequirements { // This requirement indicates the restriction_coder_id field of ParDo // transform payloads must be inspected. REQUIRES_SPLITTABLE_DOFN = 4 [(beam_urn) = "beam:requirement:pardo:splittable_dofn:v1"]; + + // This requirement indicates that the on_window_expiration_timer_family_spec field + // of ParDo transform payloads must be inspected. + REQUIRES_ON_WINDOW_EXPIRATION = 5 [(beam_urn) = "beam:requirement:pardo:on_window_expiration:v1"]; } } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java index 27f9ccd1c27a..2865f1216b48 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java @@ -30,6 +30,7 @@ import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; +import com.google.auto.value.AutoValue; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; @@ -38,6 +39,7 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; +import javax.annotation.Nullable; import org.apache.beam.model.pipeline.v1.RunnerApi; import org.apache.beam.model.pipeline.v1.RunnerApi.Components; import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec; @@ -116,6 +118,9 @@ public class ParDoTranslation { */ public static final String REQUIRES_SPLITTABLE_DOFN_URN = "beam:requirement:pardo:splittable_dofn:v1"; + /** This requirement indicates that the ParDo requires a callback on each window expiration. */ + public static final String REQUIRES_ON_WINDOW_EXPIRATION_URN = + "beam:requirement:pardo:on_window_expiration:v1"; static { checkState( @@ -132,6 +137,9 @@ public class ParDoTranslation { checkState( REQUIRES_SPLITTABLE_DOFN_URN.equals( getUrn(StandardRequirements.Enum.REQUIRES_SPLITTABLE_DOFN))); + checkState( + REQUIRES_ON_WINDOW_EXPIRATION_URN.equals( + getUrn(StandardRequirements.Enum.REQUIRES_ON_WINDOW_EXPIRATION))); } /** The URN for an unknown Java {@link DoFn}. */ @@ -281,8 +289,7 @@ public Map translateStateSpecs(SdkComponents compon } @Override - public Map translateTimerFamilySpecs( - SdkComponents newComponents) { + public ParDoLikeTimerFamilySpecs translateTimerFamilySpecs(SdkComponents newComponents) { Map timerFamilySpecs = new HashMap<>(); for (Map.Entry timer : @@ -306,14 +313,34 @@ public Map translateTimerFamilySpecs( windowCoder); timerFamilySpecs.put(timerFamily.getKey(), spec); } - return timerFamilySpecs; + + String onWindowExpirationTimerFamilySpec = null; + if (signature.onWindowExpiration() != null) { + RunnerApi.TimerFamilySpec spec = + RunnerApi.TimerFamilySpec.newBuilder() + .setTimeDomain(translateTimeDomain(TimeDomain.EVENT_TIME)) + .setTimerFamilyCoderId( + registerCoderOrThrow(components, Timer.Coder.of(keyCoder, windowCoder))) + .build(); + for (int i = 0; i < Integer.MAX_VALUE; ++i) { + onWindowExpirationTimerFamilySpec = "onWindowExpiration" + i; + if (!timerFamilySpecs.containsKey(onWindowExpirationTimerFamilySpec)) { + break; + } + } + timerFamilySpecs.put(onWindowExpirationTimerFamilySpec, spec); + } + + return ParDoLikeTimerFamilySpecs.create( + timerFamilySpecs, onWindowExpirationTimerFamilySpec); } @Override public boolean isStateful() { return !signature.stateDeclarations().isEmpty() || !signature.timerDeclarations().isEmpty() - || !signature.timerFamilyDeclarations().isEmpty(); + || !signature.timerFamilyDeclarations().isEmpty() + || signature.onWindowExpiration() != null; } @Override @@ -645,7 +672,7 @@ static StateSpec fromProto(RunnerApi.StateSpec stateSpec, RehydratedComponent } } - private static String registerCoderOrThrow(SdkComponents components, Coder coder) { + public static String registerCoderOrThrow(SdkComponents components, Coder coder) { try { return components.registerCoder(coder); } catch (IOException exc) { @@ -665,7 +692,7 @@ public static RunnerApi.TimerFamilySpec translateTimerFamilySpec( .build(); } - private static RunnerApi.TimeDomain.Enum translateTimeDomain(TimeDomain timeDomain) { + public static RunnerApi.TimeDomain.Enum translateTimeDomain(TimeDomain timeDomain) { switch (timeDomain) { case EVENT_TIME: return RunnerApi.TimeDomain.Enum.EVENT_TIME; @@ -769,6 +796,22 @@ public static FunctionSpec translateWindowMappingFn( .build(); } + @AutoValue + public abstract static class ParDoLikeTimerFamilySpecs { + + public static ParDoLikeTimerFamilySpecs create( + Map timerFamilySpecs, + @Nullable String onWindowExpirationTimerFamilySpec) { + return new AutoValue_ParDoTranslation_ParDoLikeTimerFamilySpecs( + timerFamilySpecs, onWindowExpirationTimerFamilySpec); + } + + abstract Map timerFamilySpecs(); + + @Nullable + abstract String onWindowExpirationTimerFamilySpec(); + } + /** These methods drive to-proto translation from Java and from rehydrated ParDos. */ public interface ParDoLike { FunctionSpec translateDoFn(SdkComponents newComponents); @@ -778,7 +821,7 @@ public interface ParDoLike { Map translateStateSpecs(SdkComponents components) throws IOException; - Map translateTimerFamilySpecs(SdkComponents newComponents); + ParDoLikeTimerFamilySpecs translateTimerFamilySpecs(SdkComponents newComponents); boolean isStateful(); @@ -812,15 +855,24 @@ public static ParDoPayload payloadForParDoLike(ParDoLike parDo, SdkComponents co components.addRequirement(REQUIRES_TIME_SORTED_INPUT_URN); } - return ParDoPayload.newBuilder() - .setDoFn(parDo.translateDoFn(components)) - .putAllStateSpecs(parDo.translateStateSpecs(components)) - .putAllTimerFamilySpecs(parDo.translateTimerFamilySpecs(components)) - .putAllSideInputs(parDo.translateSideInputs(components)) - .setRequiresStableInput(parDo.isRequiresStableInput()) - .setRequiresTimeSortedInput(parDo.isRequiresTimeSortedInput()) - .setRestrictionCoderId(parDo.translateRestrictionCoderId(components)) - .setRequestsFinalization(parDo.requestsFinalization()) - .build(); + ParDoLikeTimerFamilySpecs timerFamilySpecs = parDo.translateTimerFamilySpecs(components); + ParDoPayload.Builder builder = + ParDoPayload.newBuilder() + .setDoFn(parDo.translateDoFn(components)) + .putAllStateSpecs(parDo.translateStateSpecs(components)) + .putAllTimerFamilySpecs(timerFamilySpecs.timerFamilySpecs()) + .putAllSideInputs(parDo.translateSideInputs(components)) + .setRequiresStableInput(parDo.isRequiresStableInput()) + .setRequiresTimeSortedInput(parDo.isRequiresTimeSortedInput()) + .setRestrictionCoderId(parDo.translateRestrictionCoderId(components)) + .setRequestsFinalization(parDo.requestsFinalization()); + + if (timerFamilySpecs.onWindowExpirationTimerFamilySpec() != null) { + components.addRequirement(REQUIRES_ON_WINDOW_EXPIRATION_URN); + builder.setOnWindowExpirationTimerFamilySpec( + timerFamilySpecs.onWindowExpirationTimerFamilySpec()); + } + + return builder.build(); } } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java index 0b9fe9547245..214f02f8e00d 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java @@ -29,9 +29,9 @@ import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload; import org.apache.beam.model.pipeline.v1.RunnerApi.SideInput; import org.apache.beam.model.pipeline.v1.RunnerApi.StateSpec; -import org.apache.beam.model.pipeline.v1.RunnerApi.TimerFamilySpec; import org.apache.beam.runners.core.construction.PTransformTranslation.TransformPayloadTranslator; import org.apache.beam.runners.core.construction.ParDoTranslation.ParDoLike; +import org.apache.beam.runners.core.construction.ParDoTranslation.ParDoLikeTimerFamilySpecs; import org.apache.beam.runners.core.construction.ReadTranslation.BoundedReadPayloadTranslator; import org.apache.beam.runners.core.construction.ReadTranslation.UnboundedReadPayloadTranslator; import org.apache.beam.sdk.Pipeline; @@ -435,17 +435,16 @@ public Map translateStateSpecs(SdkComponents components) { } @Override - public Map translateTimerFamilySpecs( + public ParDoLikeTimerFamilySpecs translateTimerFamilySpecs( SdkComponents newComponents) { // SDFs don't have timers. - return ImmutableMap.of(); + return ParDoLikeTimerFamilySpecs.create(ImmutableMap.of(), null); } @Override public boolean isStateful() { - return !signature.stateDeclarations().isEmpty() - || !signature.timerDeclarations().isEmpty() - || !signature.timerFamilyDeclarations().isEmpty(); + // SDFs don't have state or timers. + return false; } @Override diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java index 057533354bb8..837c3aa65099 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java @@ -150,6 +150,12 @@ public void testToProto() throws Exception { assertEquals( parDo.getFn() instanceof StateTimerDropElementsFn, components.requirements().contains(ParDoTranslation.REQUIRES_STATEFUL_PROCESSING_URN)); + assertEquals( + parDo.getFn() instanceof StateTimerDropElementsFn, + components.requirements().contains(ParDoTranslation.REQUIRES_ON_WINDOW_EXPIRATION_URN)); + assertEquals( + parDo.getFn() instanceof StateTimerDropElementsFn ? "onWindowExpiration0" : "", + payload.getOnWindowExpirationTimerFamilySpec()); } @Test @@ -339,6 +345,9 @@ public void onEventTime(OnTimerContext context) {} @OnTimer(PROCESSING_TIMER_ID) public void onProcessingTime(OnTimerContext context) {} + @OnWindowExpiration + public void onWindowExpiration() {} + @Override public boolean equals(@Nullable Object other) { return other instanceof StateTimerDropElementsFn; diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java index 8b844278346d..40b3e3f2f0de 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java @@ -18,9 +18,10 @@ package org.apache.beam.runners.dataflow; import static org.apache.beam.runners.core.construction.PTransformTranslation.PAR_DO_TRANSFORM_URN; +import static org.apache.beam.runners.core.construction.ParDoTranslation.registerCoderOrThrow; +import static org.apache.beam.runners.core.construction.ParDoTranslation.translateTimeDomain; import static org.apache.beam.runners.core.construction.ParDoTranslation.translateTimerFamilySpec; import static org.apache.beam.sdk.transforms.reflect.DoFnSignatures.getStateSpecOrThrow; -import static org.apache.beam.sdk.transforms.reflect.DoFnSignatures.getTimerFamilySpecOrThrow; import static org.apache.beam.sdk.transforms.reflect.DoFnSignatures.getTimerSpecOrThrow; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; @@ -37,14 +38,17 @@ import org.apache.beam.runners.core.construction.PTransformReplacements; import org.apache.beam.runners.core.construction.PTransformTranslation; import org.apache.beam.runners.core.construction.ParDoTranslation; +import org.apache.beam.runners.core.construction.ParDoTranslation.ParDoLikeTimerFamilySpecs; import org.apache.beam.runners.core.construction.SdkComponents; import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory; +import org.apache.beam.runners.core.construction.Timer; import org.apache.beam.runners.core.construction.TransformPayloadTranslatorRegistrar; import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.runners.PTransformOverrideFactory; +import org.apache.beam.sdk.state.TimeDomain; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFnSchemaInformation; import org.apache.beam.sdk.transforms.PTransform; @@ -54,6 +58,7 @@ import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; import org.apache.beam.sdk.transforms.reflect.DoFnSignature; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.TimerDeclaration; import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.values.PCollection; @@ -245,37 +250,59 @@ public Map translateStateSpecs(SdkComponents compon } @Override - public Map translateTimerFamilySpecs( + public ParDoLikeTimerFamilySpecs translateTimerFamilySpecs( SdkComponents newComponents) { Map timerFamilySpecs = new HashMap<>(); - for (Map.Entry timerFamily : - signature.timerFamilyDeclarations().entrySet()) { + + for (Map.Entry timer : + signature.timerDeclarations().entrySet()) { RunnerApi.TimerFamilySpec spec = translateTimerFamilySpec( - getTimerFamilySpecOrThrow(timerFamily.getValue(), doFn), + getTimerSpecOrThrow(timer.getValue(), doFn), newComponents, keyCoder, windowCoder); - timerFamilySpecs.put(timerFamily.getKey(), spec); + timerFamilySpecs.put(timer.getKey(), spec); } - for (Map.Entry timer : - signature.timerDeclarations().entrySet()) { + + for (Map.Entry timerFamily : + signature.timerFamilyDeclarations().entrySet()) { RunnerApi.TimerFamilySpec spec = translateTimerFamilySpec( - getTimerSpecOrThrow(timer.getValue(), doFn), + DoFnSignatures.getTimerFamilySpecOrThrow(timerFamily.getValue(), doFn), newComponents, keyCoder, windowCoder); - timerFamilySpecs.put(timer.getKey(), spec); + timerFamilySpecs.put(timerFamily.getKey(), spec); } - return timerFamilySpecs; + + String onWindowExpirationTimerFamilySpec = null; + if (signature.onWindowExpiration() != null) { + RunnerApi.TimerFamilySpec spec = + RunnerApi.TimerFamilySpec.newBuilder() + .setTimeDomain(translateTimeDomain(TimeDomain.EVENT_TIME)) + .setTimerFamilyCoderId( + registerCoderOrThrow(components, Timer.Coder.of(keyCoder, windowCoder))) + .build(); + for (int i = 0; i < Integer.MAX_VALUE; ++i) { + onWindowExpirationTimerFamilySpec = "onWindowExpiration" + i; + if (!timerFamilySpecs.containsKey(onWindowExpirationTimerFamilySpec)) { + break; + } + } + timerFamilySpecs.put(onWindowExpirationTimerFamilySpec, spec); + } + + return ParDoLikeTimerFamilySpecs.create( + timerFamilySpecs, onWindowExpirationTimerFamilySpec); } @Override public boolean isStateful() { return !signature.stateDeclarations().isEmpty() || !signature.timerDeclarations().isEmpty() - || !signature.timerFamilyDeclarations().isEmpty(); + || !signature.timerFamilyDeclarations().isEmpty() + || signature.onWindowExpiration() != null; } @Override diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java index 30a780c426a0..24e5441f8697 100644 --- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java +++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java @@ -1622,9 +1622,10 @@ public void processElement( @OnTimer("event") public void eventTimer( OnTimerContext context, + @Key String key, @TimerId("event") Timer eventTimeTimer, @TimerId("processing") Timer processingTimeTimer) { - context.output(KV.of("event", "")); + context.output(KV.of("event", key)); eventTimeTimer .withOutputTimestamp(context.timestamp()) .set(context.fireTimestamp().plus(Duration.millis(11L))); @@ -1635,15 +1636,22 @@ public void eventTimer( @OnTimer("processing") public void processingTimer( OnTimerContext context, + @Key String key, @TimerId("event") Timer eventTimeTimer, @TimerId("processing") Timer processingTimeTimer) { - context.output(KV.of("processing", "")); + context.output(KV.of("processing", key)); eventTimeTimer .withOutputTimestamp(context.timestamp()) .set(context.fireTimestamp().plus(Duration.millis(21L))); processingTimeTimer.offset(Duration.millis(22L)); processingTimeTimer.setRelative(); } + + @OnWindowExpiration + public void onWindowExpiration( + @Key String key, OutputReceiver> outputReceiver) { + outputReceiver.output(KV.of("onWindowExpiration", key)); + } })) // Force the output to be materialized .apply("gbk", GroupByKey.create()); @@ -1702,10 +1710,13 @@ public void processingTimer( ProcessBundleDescriptors.TimerSpec eventTimerSpec = null; ProcessBundleDescriptors.TimerSpec processingTimerSpec = null; + ProcessBundleDescriptors.TimerSpec onWindowExpirationSpec = null; for (Map timerSpecs : descriptor.getTimerSpecs().values()) { for (ProcessBundleDescriptors.TimerSpec timerSpec : timerSpecs.values()) { - if (TimeDomain.EVENT_TIME.equals(timerSpec.getTimerSpec().getTimeDomain())) { + if ("onWindowExpiration0".equals(timerSpec.timerId())) { + onWindowExpirationSpec = timerSpec; + } else if (TimeDomain.EVENT_TIME.equals(timerSpec.getTimerSpec().getTimeDomain())) { eventTimerSpec = timerSpec; } else if (TimeDomain.PROCESSING_TIME.equals(timerSpec.getTimerSpec().getTimeDomain())) { processingTimerSpec = timerSpec; @@ -1737,6 +1748,12 @@ public void processingTimer( .getTimerReceivers() .get(KV.of(processingTimerSpec.transformId(), processingTimerSpec.timerId())) .accept(timerForTest("Z", 2000L, 200L)); + bundle + .getTimerReceivers() + .get(KV.of(onWindowExpirationSpec.transformId(), onWindowExpirationSpec.timerId())) + // Normally fireTimestamp and holdTimestamp would be the same in window expirations but + // we specifically set them to different values to ensure that they are used correctly. + .accept(timerForTest("key", 5001L, 5000L)); } String mainOutputTransform = Iterables.getOnlyElement(descriptor.getRemoteOutputCoders().keySet()); @@ -1745,11 +1762,14 @@ public void processingTimer( containsInAnyOrder( valueInGlobalWindow(KV.of("mainX", "")), WindowedValue.timestampedValueInGlobalWindow( - KV.of("event", ""), + KV.of("event", "Y"), BoundedWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(100L))), WindowedValue.timestampedValueInGlobalWindow( - KV.of("processing", ""), - BoundedWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(200L))))); + KV.of("processing", "Z"), + BoundedWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(200L))), + WindowedValue.timestampedValueInGlobalWindow( + KV.of("onWindowExpiration", "key"), + BoundedWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(5000L))))); assertThat( timerValues.get(KV.of(eventTimerSpec.transformId(), eventTimerSpec.timerId())), containsInAnyOrder( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java index f28493ad6a21..83859ae56182 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java @@ -2510,7 +2510,9 @@ public static boolean usesWatermarkHold(DoFn doFn) { } public static boolean usesTimers(DoFn doFn) { - return signatureForDoFn(doFn).usesTimers() || requiresTimeSortedInput(doFn); + return signatureForDoFn(doFn).usesTimers() + || requiresTimeSortedInput(doFn) + || signatureForDoFn(doFn).onWindowExpiration() != null; } public static boolean usesState(DoFn doFn) { diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java index 9166a33c1459..1c6a32779911 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java @@ -1374,6 +1374,34 @@ public void test() { } } + /** + * It is important that we don't add any state/timers to this class to ensure that statefulness is + * detected from {@link DoFn.OnWindowExpiration @OnWindowExpiration} only. + */ + private static class StatefulWithOnWindowExpiration extends DoFn, String> + implements FeatureTest { + + @ProcessElement + public void process(@Element KV input) {} + + @OnWindowExpiration + public void onWindowExpiration() {} + + @Override + public void test() { + assertThat(DoFnSignatures.isSplittable(this), SerializableMatchers.equalTo(false)); + assertThat(DoFnSignatures.isStateful(this), SerializableMatchers.equalTo(true)); + assertThat(DoFnSignatures.usesTimers(this), SerializableMatchers.equalTo(true)); + assertThat(DoFnSignatures.usesState(this), SerializableMatchers.equalTo(false)); + assertThat(DoFnSignatures.usesBagState(this), SerializableMatchers.equalTo(false)); + assertThat(DoFnSignatures.usesMapState(this), SerializableMatchers.equalTo(false)); + assertThat(DoFnSignatures.usesSetState(this), SerializableMatchers.equalTo(false)); + assertThat(DoFnSignatures.usesValueState(this), SerializableMatchers.equalTo(false)); + assertThat(DoFnSignatures.usesWatermarkHold(this), SerializableMatchers.equalTo(false)); + assertThat(DoFnSignatures.requiresTimeSortedInput(this), SerializableMatchers.equalTo(false)); + } + } + private static class StatefulWithTimers extends DoFn, String> implements FeatureTest { @@ -1553,6 +1581,7 @@ public void test() { Lists.newArrayList( new StatelessDoFn(), new StatefulWithValueState(), + new StatefulWithOnWindowExpiration(), new StatefulWithTimers(), new StatefulWithTimersAndValueState(), new StatefulWithSetState(), diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java index 14eae0fa0775..aa4cb6486403 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java @@ -197,8 +197,14 @@ static class Factory> coder = entry.getValue().getValue(); - context.addIncomingTimerEndpoint( - localName, coder, timer -> runner.processTimer(localName, timeDomain, timer)); + if (!localName.equals("") + && localName.equals(runner.parDoPayload.getOnWindowExpirationTimerFamilySpec())) { + context.addIncomingTimerEndpoint( + localName, coder, timer -> runner.processOnWindowExpiration(timer)); + } else { + context.addIncomingTimerEndpoint( + localName, coder, timer -> runner.processTimer(localName, timeDomain, timer)); + } } return runner; } @@ -238,6 +244,7 @@ static class Factory onTimerContext; + private final OnWindowExpirationContext onWindowExpirationContext; private final FinishBundleArgumentProvider finishBundleArgumentProvider; /** @@ -279,14 +286,14 @@ static class Factory currentWatermarkEstimator; @@ -306,10 +316,15 @@ static class Factory currentTracker; - /** Only valid during {@link #processTimer}, null otherwise. */ + /** + * Only valid during {@link #processTimer} and {@link #processOnWindowExpiration}, null otherwise. + */ private Timer currentTimer; /** Only valid during {@link #processTimer}, null otherwise. */ @@ -462,6 +477,7 @@ private interface TriFunction { this.splitListener = splitListener; this.bundleFinalizer = bundleFinalizer; this.onTimerContext = new OnTimerContext(); + this.onWindowExpirationContext = new OnWindowExpirationContext<>(); try { this.mainInputId = ParDoTranslation.getMainInputName(pTransform); @@ -1720,6 +1736,23 @@ private void processTimerDirect( doFnInvoker.invokeOnTimer(timerId, timerFamilyId, onTimerContext); } + private void processOnWindowExpiration(Timer timer) { + try { + currentKey = timer.getUserKey(); + currentTimer = timer; + Iterator windowIterator = + (Iterator) timer.getWindows().iterator(); + while (windowIterator.hasNext()) { + currentWindow = windowIterator.next(); + doFnInvoker.invokeOnWindowExpiration(onWindowExpirationContext); + } + } finally { + currentKey = null; + currentTimer = null; + currentWindow = null; + } + } + private void finishBundle() throws Exception { timerBundleTracker.outputTimers(timerFamilyOrId -> outboundTimerReceivers.get(timerFamilyOrId)); for (CloseableFnDataReceiver outboundTimerReceiver : outboundTimerReceivers.values()) { @@ -2508,6 +2541,159 @@ public WatermarkEstimator watermarkEstimator() { } } + /** + * Provides arguments for a {@link DoFnInvoker} for {@link + * DoFn.OnWindowExpiration @OnWindowExpiration}. + */ + private class OnWindowExpirationContext extends BaseArgumentProvider { + private class Context extends DoFn.OnWindowExpirationContext { + private Context() { + doFn.super(); + } + + @Override + public PipelineOptions getPipelineOptions() { + return pipelineOptions; + } + + @Override + public BoundedWindow window() { + return currentWindow; + } + + @Override + public void output(OutputT output) { + outputTo( + mainOutputConsumers, + WindowedValue.of( + output, currentTimer.getHoldTimestamp(), currentWindow, currentTimer.getPane())); + } + + @Override + public void outputWithTimestamp(OutputT output, Instant timestamp) { + checkOnWindowExpirationTimestamp(timestamp); + outputTo( + mainOutputConsumers, + WindowedValue.of(output, timestamp, currentWindow, currentTimer.getPane())); + } + + @Override + public void output(TupleTag tag, T output) { + Collection>> consumers = + (Collection) localNameToConsumer.get(tag.getId()); + if (consumers == null) { + throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); + } + outputTo( + consumers, + WindowedValue.of( + output, currentTimer.getHoldTimestamp(), currentWindow, currentTimer.getPane())); + } + + @Override + public void outputWithTimestamp(TupleTag tag, T output, Instant timestamp) { + checkOnWindowExpirationTimestamp(timestamp); + Collection>> consumers = + (Collection) localNameToConsumer.get(tag.getId()); + if (consumers == null) { + throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); + } + outputTo( + consumers, WindowedValue.of(output, timestamp, currentWindow, currentTimer.getPane())); + } + + @SuppressWarnings( + "deprecation") // Allowed Skew is deprecated for users, but must be respected + private void checkOnWindowExpirationTimestamp(Instant timestamp) { + Instant lowerBound; + try { + lowerBound = currentTimer.getHoldTimestamp().minus(doFn.getAllowedTimestampSkew()); + } catch (ArithmeticException e) { + lowerBound = BoundedWindow.TIMESTAMP_MIN_VALUE; + } + if (timestamp.isBefore(lowerBound) + || timestamp.isAfter(BoundedWindow.TIMESTAMP_MAX_VALUE)) { + throw new IllegalArgumentException( + String.format( + "Cannot output with timestamp %s. Output timestamps must be no earlier than the " + + "timestamp of the timer (%s) minus the allowed skew (%s) and no later " + + "than %s. See the DoFn#getAllowedTimestampSkew() Javadoc for details on " + + "changing the allowed skew.", + timestamp, + currentTimer.getHoldTimestamp(), + PeriodFormat.getDefault().print(doFn.getAllowedTimestampSkew().toPeriod()), + BoundedWindow.TIMESTAMP_MAX_VALUE)); + } + } + } + + private final OnWindowExpirationContext.Context context = + new OnWindowExpirationContext.Context(); + + @Override + public BoundedWindow window() { + return currentWindow; + } + + @Override + public Instant timestamp(DoFn doFn) { + return currentTimer.getHoldTimestamp(); + } + + @Override + public TimeDomain timeDomain(DoFn doFn) { + return currentTimeDomain; + } + + @Override + public K key() { + return (K) currentTimer.getUserKey(); + } + + @Override + public OutputReceiver outputReceiver(DoFn doFn) { + return DoFnOutputReceivers.windowedReceiver(context, null); + } + + @Override + public OutputReceiver outputRowReceiver(DoFn doFn) { + return DoFnOutputReceivers.rowReceiver(context, null, mainOutputSchemaCoder); + } + + @Override + public MultiOutputReceiver taggedOutputReceiver(DoFn doFn) { + return DoFnOutputReceivers.windowedMultiReceiver(context); + } + + @Override + public State state(String stateId, boolean alwaysFetched) { + StateDeclaration stateDeclaration = doFnSignature.stateDeclarations().get(stateId); + checkNotNull(stateDeclaration, "No state declaration found for %s", stateId); + StateSpec spec; + try { + spec = (StateSpec) stateDeclaration.field().get(doFn); + } catch (IllegalAccessException e) { + throw new RuntimeException(e); + } + State state = spec.bind(stateId, stateAccessor); + if (alwaysFetched) { + return (State) ((ReadableState) state).readLater(); + } else { + return state; + } + } + + @Override + public PipelineOptions pipelineOptions() { + return pipelineOptions; + } + + @Override + public String getErrorContext() { + return "FnApiDoFnRunner/OnWindowExpiration"; + } + } + /** Provides arguments for a {@link DoFnInvoker} for {@link DoFn.OnTimer @OnTimer}. */ private class OnTimerContext extends BaseArgumentProvider {