diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/ReaderInvocationUtil.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/ReaderInvocationUtil.java index 736a2dd9da59..1b52354bf01a 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/ReaderInvocationUtil.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/metrics/ReaderInvocationUtil.java @@ -22,6 +22,7 @@ import org.apache.beam.runners.core.metrics.MetricsContainerImpl; import org.apache.beam.runners.flink.FlinkPipelineOptions; import org.apache.beam.sdk.io.Source; +import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.metrics.MetricsEnvironment; import org.apache.beam.sdk.options.PipelineOptions; @@ -69,4 +70,18 @@ public boolean invokeAdvance(ReaderT reader) throws IOException { return reader.advance(); } } + + public UnboundedSource.CheckpointMark invokeCheckpointMark( + UnboundedSource.UnboundedReader reader) throws IOException { + if (enableMetrics) { + try (Closeable ignored = + MetricsEnvironment.scopedMetricsContainer(container.getMetricsContainer(stepName))) { + UnboundedSource.CheckpointMark result = reader.getCheckpointMark(); + container.updateMetrics(stepName); + return result; + } + } else { + return reader.getCheckpointMark(); + } + } } diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java index 92d0652e11f8..cd187b4da02d 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java @@ -391,6 +391,9 @@ public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throw return; } + ReaderInvocationUtil> readerInvoker = + new ReaderInvocationUtil<>(stepName, serializedOptions.get(), metricContainer); + stateForCheckpoint.clear(); long checkpointId = functionSnapshotContext.getCheckpointId(); @@ -405,7 +408,7 @@ public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throw UnboundedSource.UnboundedReader reader = localReaders.get(i); @SuppressWarnings("unchecked") - CheckpointMarkT mark = (CheckpointMarkT) reader.getCheckpointMark(); + CheckpointMarkT mark = (CheckpointMarkT) readerInvoker.invokeCheckpointMark(reader); checkpointMarks.add(mark); KV, CheckpointMarkT> kv = KV.of(source, mark); stateForCheckpoint.add(kv);