From 561df7c7f77264074a003a91a6deaf1a69999094 Mon Sep 17 00:00:00 2001 From: gang_ye Date: Fri, 21 Apr 2023 10:13:12 -0700 Subject: [PATCH] Backport to Flink 1.15: Implement DataStatisticsOperator operator to collect and send traffic distribution for guiding smart shuffling --- .../flink/sink/shuffle/DataStatistics.java | 53 +++++ .../sink/shuffle/DataStatisticsEvent.java | 57 +++++ .../sink/shuffle/DataStatisticsFactory.java | 33 +++ .../sink/shuffle/DataStatisticsOperator.java | 150 ++++++++++++ .../sink/shuffle/DataStatisticsOrRecord.java | 81 +++++++ .../flink/sink/shuffle/MapDataStatistics.java | 60 +++++ .../shuffle/MapDataStatisticsFactory.java | 34 +++ .../shuffle/TestDataStatisticsOperator.java | 217 ++++++++++++++++++ 8 files changed, 685 insertions(+) create mode 100644 flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatistics.java create mode 100644 flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java create mode 100644 flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsFactory.java create mode 100644 flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java create mode 100644 flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOrRecord.java create mode 100644 flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatistics.java create mode 100644 flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatisticsFactory.java create mode 100644 flink/v1.15/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java diff --git a/flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatistics.java b/flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatistics.java new file mode 100644 index 000000000000..cf6257f66d3c --- /dev/null +++ 
b/flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatistics.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import org.apache.flink.annotation.Internal; + +/** + * DataStatistics defines the interface to collect data distribution information. + * + *

Data statistics tracks traffic volume distribution across data keys. For low-cardinality key, + * a simple map of (key, count) can be used. For high-cardinality key, probabilistic data structures + * (sketching) can be used. + */ +@Internal +interface DataStatistics { + + /** + * Check if data statistics is empty, i.e. contains no statistics information. + * + * @return true if data statistics doesn't contain any statistics information + */ + boolean isEmpty(); + + /** + * Add data key to data statistics. + * + * @param key generated from data by applying key selector + */ + void add(K key); + + /** + * Merge current statistics with other statistics + * + * @param otherStatistics the statistics to be merged + */ + void merge(DataStatistics otherStatistics); +} diff --git a/flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java b/flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java new file mode 100644 index 000000000000..5725188c333f --- /dev/null +++ b/flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.flink.sink.shuffle; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.operators.coordination.OperatorEvent; +import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; + +/** + * DataStatisticsEvent is sent between data statistics coordinator and operator to transmit data + * statistics, tagged with a checkpoint id. + */ +@Internal +class DataStatisticsEvent implements OperatorEvent { + + private static final long serialVersionUID = 1L; + + private final long checkpointId; + private final DataStatistics dataStatistics; + + DataStatisticsEvent(long checkpointId, DataStatistics dataStatistics) { + this.checkpointId = checkpointId; + this.dataStatistics = dataStatistics; + } + + long checkpointId() { + return checkpointId; + } + + DataStatistics dataStatistics() { + return dataStatistics; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("checkpointId", checkpointId) + .add("dataStatistics", dataStatistics) + .toString(); + } +} diff --git a/flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsFactory.java b/flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsFactory.java new file mode 100644 index 000000000000..02f03aef5b82 --- /dev/null +++ b/flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsFactory.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import org.apache.flink.annotation.Internal; + +/** + * DataStatisticsFactory defines the interface to create {@link DataStatistics}. + * + *

For low-cardinality keys, MapDataStatisticsFactory is the implementation that creates + * MapDataStatistics. + */ +@Internal +interface DataStatisticsFactory { + + DataStatistics createDataStatistics(); +} diff --git a/flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java b/flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java new file mode 100644 index 000000000000..a90bf7cdc728 --- /dev/null +++ b/flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.flink.sink.shuffle; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeHint; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.runtime.operators.coordination.OperatorEvent; +import org.apache.flink.runtime.operators.coordination.OperatorEventGateway; +import org.apache.flink.runtime.operators.coordination.OperatorEventHandler; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; + +/** + * DataStatisticsOperator collects traffic distribution statistics. A custom partitioner shall be + * attached to the DataStatisticsOperator output. The custom partitioner leverages the statistics to + * shuffle records to improve data clustering while maintaining relatively balanced traffic + * distribution to downstream subtasks. 
+ */ +class DataStatisticsOperator extends AbstractStreamOperator> + implements OneInputStreamOperator>, OperatorEventHandler { + private static final long serialVersionUID = 1L; + + // keySelector will be used to generate key from data for collecting data statistics + private final KeySelector keySelector; + private final OperatorEventGateway operatorEventGateway; + private final DataStatisticsFactory statisticsFactory; + private transient volatile DataStatistics localStatistics; + private transient volatile DataStatistics globalStatistics; + private transient ListState> globalStatisticsState; + + DataStatisticsOperator( + KeySelector keySelector, + OperatorEventGateway operatorEventGateway, + DataStatisticsFactory statisticsFactory) { + this.keySelector = keySelector; + this.operatorEventGateway = operatorEventGateway; + this.statisticsFactory = statisticsFactory; + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + localStatistics = statisticsFactory.createDataStatistics(); + globalStatisticsState = + context + .getOperatorStateStore() + .getUnionListState( + new ListStateDescriptor<>( + "globalStatisticsState", + TypeInformation.of(new TypeHint>() {}))); + + if (context.isRestored()) { + int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask(); + if (globalStatisticsState.get() == null + || !globalStatisticsState.get().iterator().hasNext()) { + LOG.warn("Subtask {} doesn't have global statistics state to restore", subtaskIndex); + globalStatistics = statisticsFactory.createDataStatistics(); + } else { + LOG.info("Restoring global statistics state for subtask {}", subtaskIndex); + globalStatistics = globalStatisticsState.get().iterator().next(); + } + } else { + globalStatistics = statisticsFactory.createDataStatistics(); + } + } + + @Override + public void open() throws Exception { + if (!globalStatistics.isEmpty()) { + output.collect( + new 
StreamRecord<>(DataStatisticsOrRecord.fromDataStatistics(globalStatistics))); + } + } + + @Override + @SuppressWarnings("unchecked") + public void handleOperatorEvent(OperatorEvent event) { + Preconditions.checkArgument( + event instanceof DataStatisticsEvent, + "Received unexpected operator event " + event.getClass()); + globalStatistics = ((DataStatisticsEvent) event).dataStatistics(); + output.collect(new StreamRecord<>(DataStatisticsOrRecord.fromDataStatistics(globalStatistics))); + } + + @Override + public void processElement(StreamRecord streamRecord) throws Exception { + T record = streamRecord.getValue(); + K key = keySelector.getKey(record); + localStatistics.add(key); + output.collect(new StreamRecord<>(DataStatisticsOrRecord.fromRecord(record))); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + long checkpointId = context.getCheckpointId(); + int subTaskId = getRuntimeContext().getIndexOfThisSubtask(); + LOG.info( + "Taking data statistics operator snapshot for checkpoint {} in subtask {}", + checkpointId, + subTaskId); + + // Only subtask 0 saves the state so that globalStatisticsState(UnionListState) stores + // an exact copy of globalStatistics + if (!globalStatistics.isEmpty() && getRuntimeContext().getIndexOfThisSubtask() == 0) { + globalStatisticsState.clear(); + LOG.info("Saving global statistics {} to state in subtask {}", globalStatistics, subTaskId); + globalStatisticsState.add(globalStatistics); + } + + // For now, we keep it simple and send localStatistics to the coordinator at checkpoint + operatorEventGateway.sendEventToCoordinator( + new DataStatisticsEvent<>(checkpointId, localStatistics)); + + // Recreate the local statistics + localStatistics = statisticsFactory.createDataStatistics(); + } + + @VisibleForTesting + DataStatistics localDataStatistics() { + return localStatistics; + } + + @VisibleForTesting + DataStatistics globalDataStatistics() { + return globalStatistics; + } +} diff --git 
a/flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOrRecord.java b/flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOrRecord.java new file mode 100644 index 000000000000..4cbd1bb078a3 --- /dev/null +++ b/flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOrRecord.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.io.Serializable; +import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; + +/** + * The wrapper class for data statistics and record. It is the only way for data statistics operator + * to send global data statistics to custom partitioner to distribute data based on statistics. + * + *

DataStatisticsOrRecord contains either data statistics(globally aggregated) or a record. It is + * sent from {@link DataStatisticsOperator} to partitioner. Once partitioner receives the data + * statistics, it will use them to decide which writer subtask each incoming record is sent to. + * After shuffling, a filter and mapper are required to filter out the data distribution weight, + * unwrap the object and extract the original record type T. + */ +class DataStatisticsOrRecord implements Serializable { + + private static final long serialVersionUID = 1L; + + private final DataStatistics statistics; + private final T record; + + private DataStatisticsOrRecord(T record, DataStatistics statistics) { + Preconditions.checkArgument( + record != null ^ statistics != null, + "A DataStatisticsOrRecord contain either statistics or record, not neither or both"); + this.statistics = statistics; + this.record = record; + } + + static DataStatisticsOrRecord fromRecord(T record) { + return new DataStatisticsOrRecord<>(record, null); + } + + static DataStatisticsOrRecord fromDataStatistics(DataStatistics statistics) { + return new DataStatisticsOrRecord<>(null, statistics); + } + + boolean hasDataStatistics() { + return statistics != null; + } + + boolean hasRecord() { + return record != null; + } + + DataStatistics dataStatistics() { + return statistics; + } + + T record() { + return record; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("statistics", statistics) + .add("record", record) + .toString(); + } +} diff --git a/flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatistics.java b/flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatistics.java new file mode 100644 index 000000000000..4f8fd9445b4b --- /dev/null +++ b/flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatistics.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Map; +import org.apache.flink.annotation.Internal; +import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; + +/** MapDataStatistics uses a map to count key frequency. 
*/ +@Internal +class MapDataStatistics implements DataStatistics { + private final Map statistics = Maps.newHashMap(); + + @Override + public boolean isEmpty() { + return statistics.size() == 0; + } + + @Override + public void add(K key) { + // increase the occurrence count of the key by one in the statistics map + statistics.merge(key, 1L, Long::sum); + } + + @Override + public void merge(DataStatistics otherStatistics) { + Preconditions.checkArgument( + otherStatistics instanceof MapDataStatistics, + "Map statistics can not merge with " + otherStatistics.getClass()); + MapDataStatistics mapDataStatistic = (MapDataStatistics) otherStatistics; + mapDataStatistic.statistics.forEach((key, count) -> statistics.merge(key, count, Long::sum)); + } + + public Map dataStatistics() { + return statistics; + } + + @Override + public 
String toString() { + return MoreObjects.toStringHelper(this).add("statistics", statistics).toString(); + } +} diff --git a/flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatisticsFactory.java b/flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatisticsFactory.java new file mode 100644 index 000000000000..5c78f6fc14ad --- /dev/null +++ b/flink/v1.15/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatisticsFactory.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.flink.sink.shuffle; + +import org.apache.flink.annotation.Internal; + +/** + * MapDataStatisticsFactory creates {@link MapDataStatistics} to track traffic volume for + * low-cardinality keys in hash mode. + */ +@Internal +class MapDataStatisticsFactory implements DataStatisticsFactory { + + @Override + public DataStatistics createDataStatistics() { + return new MapDataStatistics<>(); + } +} diff --git a/flink/v1.15/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java b/flink/v1.15/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java new file mode 100644 index 000000000000..e8e8897d932c --- /dev/null +++ b/flink/v1.15/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.flink.sink.shuffle; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.state.OperatorStateStore; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.fs.CloseableRegistry; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.runtime.execution.Environment; +import org.apache.flink.runtime.operators.coordination.MockOperatorEventGateway; +import org.apache.flink.runtime.operators.testutils.MockEnvironment; +import org.apache.flink.runtime.operators.testutils.MockEnvironmentBuilder; +import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; +import org.apache.flink.runtime.state.AbstractStateBackend; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateInitializationContextImpl; +import org.apache.flink.runtime.state.TestTaskStateManager; +import org.apache.flink.runtime.state.hashmap.HashMapStateBackend; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask; +import org.apache.flink.streaming.runtime.tasks.StreamMockEnvironment; +import org.apache.flink.streaming.util.MockOutput; +import org.apache.flink.streaming.util.MockStreamConfig; +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class TestDataStatisticsOperator { + 
private DataStatisticsOperator operator; + + private Environment getTestingEnvironment() { + return new StreamMockEnvironment( + new Configuration(), + new Configuration(), + new ExecutionConfig(), + 1L, + new MockInputSplitProvider(), + 1, + new TestTaskStateManager()); + } + + @Before + public void before() throws Exception { + this.operator = createOperator(); + Environment env = getTestingEnvironment(); + this.operator.setup( + new OneInputStreamTask(env), + new MockStreamConfig(new Configuration(), 1), + new MockOutput<>(Lists.newArrayList())); + } + + private DataStatisticsOperator createOperator() { + MockOperatorEventGateway mockGateway = new MockOperatorEventGateway(); + KeySelector keySelector = + new KeySelector() { + private static final long serialVersionUID = 7662520075515707428L; + + @Override + public String getKey(String value) { + return value; + } + }; + DataStatisticsFactory dataStatisticsFactory = new MapDataStatisticsFactory<>(); + return new DataStatisticsOperator<>(keySelector, mockGateway, dataStatisticsFactory); + } + + @After + public void clean() throws Exception { + operator.close(); + } + + @Test + public void testProcessElement() throws Exception { + StateInitializationContext stateContext = getStateContext(); + operator.initializeState(stateContext); + operator.processElement(new StreamRecord<>("a")); + operator.processElement(new StreamRecord<>("a")); + operator.processElement(new StreamRecord<>("b")); + assertTrue(operator.localDataStatistics() instanceof MapDataStatistics); + MapDataStatistics mapDataStatistics = + (MapDataStatistics) operator.localDataStatistics(); + assertTrue(mapDataStatistics.dataStatistics().containsKey("a")); + assertTrue(mapDataStatistics.dataStatistics().containsKey("b")); + assertEquals(2L, (long) mapDataStatistics.dataStatistics().get("a")); + assertEquals(1L, (long) mapDataStatistics.dataStatistics().get("b")); + } + + @Test + public void testOperatorOutput() throws Exception { + try 
(OneInputStreamOperatorTestHarness> + testHarness = createHarness(this.operator)) { + testHarness.processElement(new StreamRecord<>("a")); + testHarness.processElement(new StreamRecord<>("b")); + testHarness.processElement(new StreamRecord<>("b")); + + List recordsOutput = + testHarness.extractOutputValues().stream() + .filter(DataStatisticsOrRecord::hasRecord) + .map(DataStatisticsOrRecord::record) + .collect(Collectors.toList()); + assertThat(recordsOutput) + .containsExactlyInAnyOrderElementsOf(ImmutableList.of("a", "b", "b")); + } + } + + @Test + public void testRestoreState() throws Exception { + OperatorSubtaskState snapshot; + try (OneInputStreamOperatorTestHarness> + testHarness1 = createHarness(this.operator)) { + DataStatistics mapDataStatistics = new MapDataStatistics<>(); + mapDataStatistics.add("a"); + mapDataStatistics.add("a"); + mapDataStatistics.add("b"); + mapDataStatistics.add("c"); + operator.handleOperatorEvent(new DataStatisticsEvent<>(0, mapDataStatistics)); + assertTrue(operator.globalDataStatistics() instanceof MapDataStatistics); + assertEquals( + 2L, + (long) + ((MapDataStatistics) operator.globalDataStatistics()) + .dataStatistics() + .get("a")); + assertEquals( + 1L, + (long) + ((MapDataStatistics) operator.globalDataStatistics()) + .dataStatistics() + .get("b")); + assertEquals( + 1L, + (long) + ((MapDataStatistics) operator.globalDataStatistics()) + .dataStatistics() + .get("c")); + + snapshot = testHarness1.snapshot(1L, 0); + } + + // Use the snapshot to initialize state for another new operator and then verify that the global + // statistics for the new operator are the same as before + DataStatisticsOperator restoredOperator = createOperator(); + try (OneInputStreamOperatorTestHarness> + testHarness2 = new OneInputStreamOperatorTestHarness<>(restoredOperator, 2, 2, 1)) { + + testHarness2.setup(); + testHarness2.initializeState(snapshot); + assertTrue(restoredOperator.globalDataStatistics() instanceof MapDataStatistics); + assertEquals( 
+ 2L, + (long) + ((MapDataStatistics) restoredOperator.globalDataStatistics()) + .dataStatistics() + .get("a")); + assertEquals( + 1L, + (long) + ((MapDataStatistics) restoredOperator.globalDataStatistics()) + .dataStatistics() + .get("b")); + assertEquals( + 1L, + (long) + ((MapDataStatistics) restoredOperator.globalDataStatistics()) + .dataStatistics() + .get("c")); + } + } + + private StateInitializationContext getStateContext() throws Exception { + MockEnvironment env = new MockEnvironmentBuilder().build(); + AbstractStateBackend abstractStateBackend = new HashMapStateBackend(); + CloseableRegistry cancelStreamRegistry = new CloseableRegistry(); + OperatorStateStore operatorStateStore = + abstractStateBackend.createOperatorStateBackend( + env, "test-operator", Collections.emptyList(), cancelStreamRegistry); + return new StateInitializationContextImpl(null, operatorStateStore, null, null, null); + } + + private OneInputStreamOperatorTestHarness> + createHarness(final DataStatisticsOperator dataStatisticsOperator) + throws Exception { + OneInputStreamOperatorTestHarness> harness = + new OneInputStreamOperatorTestHarness<>(dataStatisticsOperator, 1, 1, 0); + harness.setup(); + harness.open(); + return harness; + } +}