diff --git a/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java b/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java new file mode 100644 index 000000000000..0a724cecd80c --- /dev/null +++ b/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.operators.coordination.OperatorEvent; +import org.apache.iceberg.flink.sink.shuffle.statistics.DataStatistics; +import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; + +/** + * DataStatisticsEvent is sent between data statistics coordinator and operator to transmit data + * statistics + */ +@Internal +class DataStatisticsEvent implements OperatorEvent { + + private static final long serialVersionUID = 1L; + + private final long checkpointId; + private final DataStatistics dataStatistics; + + DataStatisticsEvent(long checkpointId, DataStatistics dataStatistics) { + this.checkpointId = checkpointId; + this.dataStatistics = dataStatistics; + } + + long checkpointId() { + return checkpointId; + } + + DataStatistics dataStatistics() { + return dataStatistics; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("checkpointId", checkpointId) + .add("dataStatistics", dataStatistics) + .toString(); + } +} diff --git a/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java b/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java index 2582104de3c4..60f5e394b01f 100644 --- a/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java +++ b/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java @@ -34,6 +34,7 @@ import org.apache.iceberg.flink.sink.shuffle.statistics.DataStatistics; import org.apache.iceberg.flink.sink.shuffle.statistics.DataStatisticsFactory; import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; /** * DataStatisticsOperator collects traffic distribution statistics. A custom partitioner shall be @@ -97,8 +98,13 @@ public void open() throws Exception { } @Override - public void handleOperatorEvent(OperatorEvent evt) { - // TODO: receive event with aggregated statistics from coordinator and update globalStatistics + @SuppressWarnings("unchecked") + public void handleOperatorEvent(OperatorEvent event) { + Preconditions.checkArgument( + event instanceof DataStatisticsEvent, + "Received unexpected operator event " + event.getClass()); + globalStatistics = ((DataStatisticsEvent) event).dataStatistics(); + output.collect(new StreamRecord<>(DataStatisticsOrRecord.fromDataStatistics(globalStatistics))); } @Override @@ -126,8 +132,9 @@ public void snapshotState(StateSnapshotContext context) throws Exception { globalStatisticsState.add(globalStatistics); } - // TODO: send to coordinator - // For now we make it simple to send globalStatisticsState at checkpoint + // For now, we make it simple to send globalStatisticsState at checkpoint + operatorEventGateway.sendEventToCoordinator( + new DataStatisticsEvent<>(checkpointId, localStatistics)); // Recreate the local statistics localStatistics = statisticsFactory.createDataStatistics(); @@ -137,4 +144,9 @@ public void snapshotState(StateSnapshotContext context) throws Exception { DataStatistics localDataStatistics() { return localStatistics; } + + @VisibleForTesting + DataStatistics globalDataStatistics() { + return globalStatistics; + } } diff --git a/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java b/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java index 6801cfcf720b..928a9f27cf9c 100644 --- a/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java +++ b/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java @@ -30,6 +30,7 @@ import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.fs.CloseableRegistry; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.operators.coordination.MockOperatorEventGateway; import org.apache.flink.runtime.operators.testutils.MockEnvironment; @@ -46,6 +47,7 @@ import org.apache.flink.streaming.util.MockOutput; import org.apache.flink.streaming.util.MockStreamConfig; import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.iceberg.flink.sink.shuffle.statistics.DataStatistics; import org.apache.iceberg.flink.sink.shuffle.statistics.DataStatisticsFactory; import org.apache.iceberg.flink.sink.shuffle.statistics.MapDataStatistics; import org.apache.iceberg.flink.sink.shuffle.statistics.MapDataStatisticsFactory; @@ -71,6 +73,15 @@ private Environment getTestingEnvironment() { @Before public void before() throws Exception { + this.operator = createOperator(); + Environment env = getTestingEnvironment(); + this.operator.setup( + new OneInputStreamTask(env), + new MockStreamConfig(new Configuration(), 1), + new MockOutput<>(Lists.newArrayList())); + } + + private DataStatisticsOperator createOperator() { MockOperatorEventGateway mockGateway = new MockOperatorEventGateway(); KeySelector keySelector = new KeySelector() { @@ -82,13 +93,7 @@ public String getKey(String value) { } }; DataStatisticsFactory dataStatisticsFactory = new MapDataStatisticsFactory<>(); - - this.operator = new DataStatisticsOperator<>(keySelector, mockGateway, dataStatisticsFactory); - Environment env = getTestingEnvironment(); - this.operator.setup( - new OneInputStreamTask(env), - new MockStreamConfig(new Configuration(), 1), - new MockOutput<>(Lists.newArrayList())); + return new DataStatisticsOperator<>(keySelector, mockGateway, dataStatisticsFactory); } @After @@ -130,6 +135,70 @@ public void testOperatorOutput() throws Exception { } } + @Test + public void testRestoreState() throws Exception { + OperatorSubtaskState snapshot; + try (OneInputStreamOperatorTestHarness> + testHarness1 = createHarness(this.operator)) { + DataStatistics mapDataStatistics = new MapDataStatistics<>(); + mapDataStatistics.add("a"); + mapDataStatistics.add("a"); + mapDataStatistics.add("b"); + mapDataStatistics.add("c"); + operator.handleOperatorEvent(new DataStatisticsEvent<>(0, mapDataStatistics)); + assertTrue(operator.globalDataStatistics() instanceof MapDataStatistics); + assertEquals( + 2L, + (long) + ((MapDataStatistics) operator.globalDataStatistics()) + .dataStatistics() + .get("a")); + assertEquals( + 1L, + (long) + ((MapDataStatistics) operator.globalDataStatistics()) + .dataStatistics() + .get("b")); + assertEquals( + 1L, + (long) + ((MapDataStatistics) operator.globalDataStatistics()) + .dataStatistics() + .get("c")); + + snapshot = testHarness1.snapshot(1L, 0); + } + + // Use the snapshot to initialize state for another new operator and then verify that the global + // statistics for the new operator is same as before + DataStatisticsOperator restoredOperator = createOperator(); + try (OneInputStreamOperatorTestHarness> + testHarness2 = new OneInputStreamOperatorTestHarness<>(restoredOperator, 2, 2, 1)) { + + testHarness2.setup(); + testHarness2.initializeState(snapshot); + assertTrue(restoredOperator.globalDataStatistics() instanceof MapDataStatistics); + assertEquals( + 2L, + (long) + ((MapDataStatistics) restoredOperator.globalDataStatistics()) + .dataStatistics() + .get("a")); + assertEquals( + 1L, + (long) + ((MapDataStatistics) restoredOperator.globalDataStatistics()) + .dataStatistics() + .get("b")); + assertEquals( + 1L, + (long) + ((MapDataStatistics) restoredOperator.globalDataStatistics()) + .dataStatistics() + .get("c")); + } + } + private StateInitializationContext getStateContext() throws Exception { MockEnvironment env = new MockEnvironmentBuilder().build(); AbstractStateBackend abstractStateBackend = new HashMapStateBackend();