Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iceberg.flink.sink.shuffle;

import org.apache.flink.annotation.Internal;
import org.apache.flink.runtime.operators.coordination.OperatorEvent;
import org.apache.iceberg.flink.sink.shuffle.statistics.DataStatistics;
import org.apache.iceberg.relocated.com.google.common.base.MoreObjects;

/**
* DataStatisticsEvent is sent between data statistics coordinator and operator to transmit data
* statistics
*/
@Internal
class DataStatisticsEvent<K> implements OperatorEvent {

private static final long serialVersionUID = 1L;

private final long checkpointId;
private final DataStatistics<K> dataStatistics;

DataStatisticsEvent(long checkpointId, DataStatistics<K> dataStatistics) {
this.checkpointId = checkpointId;
this.dataStatistics = dataStatistics;
}

long checkpointId() {
return checkpointId;
}

DataStatistics<K> dataStatistics() {
return dataStatistics;
}

@Override
public String toString() {
return MoreObjects.toStringHelper(this)
.add("checkpointId", checkpointId)
.add("dataStatistics", dataStatistics)
.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.iceberg.flink.sink.shuffle.statistics.DataStatistics;
import org.apache.iceberg.flink.sink.shuffle.statistics.DataStatisticsFactory;
import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;

/**
* DataStatisticsOperator collects traffic distribution statistics. A custom partitioner shall be
Expand Down Expand Up @@ -97,8 +98,13 @@ public void open() throws Exception {
}

@Override
public void handleOperatorEvent(OperatorEvent evt) {
// TODO: receive event with aggregated statistics from coordinator and update globalStatistics
@SuppressWarnings("unchecked")
public void handleOperatorEvent(OperatorEvent event) {
Preconditions.checkArgument(
event instanceof DataStatisticsEvent,
"Received unexpected operator event " + event.getClass());
globalStatistics = ((DataStatisticsEvent<K>) event).dataStatistics();
output.collect(new StreamRecord<>(DataStatisticsOrRecord.fromDataStatistics(globalStatistics)));
}

@Override
Expand Down Expand Up @@ -126,8 +132,9 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
globalStatisticsState.add(globalStatistics);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just realized one thing that I missed from last PR. It can be addressed with a separate PR. We don't want to use Kryo Java serialization for the DataStatistics. We need a stable parser (E.g. SimpleVersionedSerializer). You can find some example from IcebergEnumeratorStateSerializer.

You can find some more context from #1698.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will use a follow-up PR to address the serialization.

}

// TODO: send to coordinator
// For now we make it simple to send globalStatisticsState at checkpoint
// For now, we make it simple to send globalStatisticsState at checkpoint
operatorEventGateway.sendEventToCoordinator(
new DataStatisticsEvent<>(checkpointId, localStatistics));

// Recreate the local statistics
localStatistics = statisticsFactory.createDataStatistics();
Expand All @@ -137,4 +144,9 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
DataStatistics<K> localDataStatistics() {
return localStatistics;
}

@VisibleForTesting
DataStatistics<K> globalDataStatistics() {
return globalStatistics;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.fs.CloseableRegistry;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.operators.coordination.MockOperatorEventGateway;
import org.apache.flink.runtime.operators.testutils.MockEnvironment;
Expand All @@ -46,6 +47,7 @@
import org.apache.flink.streaming.util.MockOutput;
import org.apache.flink.streaming.util.MockStreamConfig;
import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
import org.apache.iceberg.flink.sink.shuffle.statistics.DataStatistics;
import org.apache.iceberg.flink.sink.shuffle.statistics.DataStatisticsFactory;
import org.apache.iceberg.flink.sink.shuffle.statistics.MapDataStatistics;
import org.apache.iceberg.flink.sink.shuffle.statistics.MapDataStatisticsFactory;
Expand All @@ -71,6 +73,15 @@ private Environment getTestingEnvironment() {

@Before
public void before() throws Exception {
this.operator = createOperator();
Environment env = getTestingEnvironment();
this.operator.setup(
new OneInputStreamTask<String, String>(env),
new MockStreamConfig(new Configuration(), 1),
new MockOutput<>(Lists.newArrayList()));
}

private DataStatisticsOperator<String, String> createOperator() {
MockOperatorEventGateway mockGateway = new MockOperatorEventGateway();
KeySelector<String, String> keySelector =
new KeySelector<String, String>() {
Expand All @@ -82,13 +93,7 @@ public String getKey(String value) {
}
};
DataStatisticsFactory<String> dataStatisticsFactory = new MapDataStatisticsFactory<>();

this.operator = new DataStatisticsOperator<>(keySelector, mockGateway, dataStatisticsFactory);
Environment env = getTestingEnvironment();
this.operator.setup(
new OneInputStreamTask<String, String>(env),
new MockStreamConfig(new Configuration(), 1),
new MockOutput<>(Lists.newArrayList()));
return new DataStatisticsOperator<>(keySelector, mockGateway, dataStatisticsFactory);
}

@After
Expand Down Expand Up @@ -130,6 +135,70 @@ public void testOperatorOutput() throws Exception {
}
}

@Test
public void testRestoreState() throws Exception {
OperatorSubtaskState snapshot;
try (OneInputStreamOperatorTestHarness<String, DataStatisticsOrRecord<String, String>>
testHarness1 = createHarness(this.operator)) {
DataStatistics<String> mapDataStatistics = new MapDataStatistics<>();
mapDataStatistics.add("a");
mapDataStatistics.add("a");
mapDataStatistics.add("b");
mapDataStatistics.add("c");
operator.handleOperatorEvent(new DataStatisticsEvent<>(0, mapDataStatistics));
assertTrue(operator.globalDataStatistics() instanceof MapDataStatistics);
assertEquals(
2L,
(long)
((MapDataStatistics<String>) operator.globalDataStatistics())
.dataStatistics()
.get("a"));
assertEquals(
1L,
(long)
((MapDataStatistics<String>) operator.globalDataStatistics())
.dataStatistics()
.get("b"));
assertEquals(
1L,
(long)
((MapDataStatistics<String>) operator.globalDataStatistics())
.dataStatistics()
.get("c"));

snapshot = testHarness1.snapshot(1L, 0);
}

// Use the snapshot to initialize state for another new operator and then verify that the global
// statistics for the new operator is same as before
DataStatisticsOperator<String, String> restoredOperator = createOperator();
try (OneInputStreamOperatorTestHarness<String, DataStatisticsOrRecord<String, String>>
testHarness2 = new OneInputStreamOperatorTestHarness<>(restoredOperator, 2, 2, 1)) {

testHarness2.setup();
testHarness2.initializeState(snapshot);
assertTrue(restoredOperator.globalDataStatistics() instanceof MapDataStatistics);
assertEquals(
2L,
(long)
((MapDataStatistics<String>) restoredOperator.globalDataStatistics())
.dataStatistics()
.get("a"));
assertEquals(
1L,
(long)
((MapDataStatistics<String>) restoredOperator.globalDataStatistics())
.dataStatistics()
.get("b"));
assertEquals(
1L,
(long)
((MapDataStatistics<String>) restoredOperator.globalDataStatistics())
.dataStatistics()
.get("c"));
}
}

private StateInitializationContext getStateContext() throws Exception {
MockEnvironment env = new MockEnvironmentBuilder().build();
AbstractStateBackend abstractStateBackend = new HashMapStateBackend();
Expand Down