diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java index 4c91874e63f6..ce78f33c14cb 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java @@ -29,6 +29,7 @@ import org.apache.beam.fn.harness.control.HarnessMonitoringInfosInstructionHandler; import org.apache.beam.fn.harness.control.ProcessBundleHandler; import org.apache.beam.fn.harness.data.BeamFnDataGrpcClient; +import org.apache.beam.fn.harness.debug.DataSampler; import org.apache.beam.fn.harness.logging.BeamFnLoggingClient; import org.apache.beam.fn.harness.state.BeamFnStateGrpcClientCache; import org.apache.beam.fn.harness.status.BeamFnStatusClient; @@ -89,6 +90,7 @@ public class FnHarness { private static final String STATUS_API_SERVICE_DESCRIPTOR = "STATUS_API_SERVICE_DESCRIPTOR"; private static final String PIPELINE_OPTIONS = "PIPELINE_OPTIONS"; private static final String RUNNER_CAPABILITIES = "RUNNER_CAPABILITIES"; + private static final String ENABLE_DATA_SAMPLING_EXPERIMENT = "enable_data_sampling"; private static final Logger LOG = LoggerFactory.getLogger(FnHarness.class); private static Endpoints.ApiServiceDescriptor getApiServiceDescriptor(String descriptor) @@ -221,6 +223,7 @@ public static void main( options.as(ExecutorOptions.class).getScheduledExecutorService(); ExecutionStateSampler executionStateSampler = new ExecutionStateSampler(options, System::currentTimeMillis); + final DataSampler dataSampler = new DataSampler(); // The logging client variable is not used per se, but during its lifetime (until close()) it // intercepts logging and sends it to the logging service. @@ -248,6 +251,12 @@ public static void main( FinalizeBundleHandler finalizeBundleHandler = new FinalizeBundleHandler(executorService); + // Create the sampler, if the experiment is enabled. + boolean shouldSample = + ExperimentalOptions.hasExperiment(options, ENABLE_DATA_SAMPLING_EXPERIMENT); + + // Retrieves the ProcessBundleDescriptor from cache. Requests the PBD from the Runner if it + // doesn't exist. Additionally, runs any graph modifications. Function getProcessBundleDescriptor = new Function() { private static final String PROCESS_BUNDLE_DESCRIPTORS = "ProcessBundleDescriptors"; @@ -279,7 +288,8 @@ private BeamFnApi.ProcessBundleDescriptor loadDescriptor(String id) { finalizeBundleHandler, metricsShortIds, executionStateSampler, - processWideCache); + processWideCache, + shouldSample ? 
dataSampler : null); logging.setProcessBundleHandler(processBundleHandler); BeamFnStatusClient beamFnStatusClient = null; @@ -327,6 +337,8 @@ private BeamFnApi.ProcessBundleDescriptor loadDescriptor(String id) { handlers.put( InstructionRequest.RequestCase.HARNESS_MONITORING_INFOS, processWideHandler::harnessMonitoringInfos); + handlers.put( + InstructionRequest.RequestCase.SAMPLE_DATA, dataSampler::handleDataSampleRequest); JvmInitializers.runBeforeProcessing(options); diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java index 560369a3907a..348b9a761fdf 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java @@ -52,6 +52,7 @@ import org.apache.beam.fn.harness.data.BeamFnDataClient; import org.apache.beam.fn.harness.data.PCollectionConsumerRegistry; import org.apache.beam.fn.harness.data.PTransformFunctionRegistry; +import org.apache.beam.fn.harness.debug.DataSampler; import org.apache.beam.fn.harness.state.BeamFnStateClient; import org.apache.beam.fn.harness.state.BeamFnStateGrpcClientCache; import org.apache.beam.model.fnexecution.v1.BeamFnApi; @@ -164,6 +165,7 @@ public class ProcessBundleHandler { private final Cache processWideCache; @VisibleForTesting final BundleProcessorCache bundleProcessorCache; private final Set runnerCapabilities; + private final @Nullable DataSampler dataSampler; public ProcessBundleHandler( PipelineOptions options, @@ -174,7 +176,8 @@ public ProcessBundleHandler( FinalizeBundleHandler finalizeBundleHandler, ShortIdMap shortIds, ExecutionStateSampler executionStateSampler, - Cache processWideCache) { + Cache processWideCache, + @Nullable DataSampler dataSampler) { this( options, runnerCapabilities, @@ -186,7 +189,8 @@ public ProcessBundleHandler( executionStateSampler, REGISTERED_RUNNER_FACTORIES, processWideCache, - new BundleProcessorCache()); + new BundleProcessorCache(), + dataSampler); } @VisibleForTesting @@ -201,7 +205,8 @@ public ProcessBundleHandler( ExecutionStateSampler executionStateSampler, Map urnToPTransformRunnerFactoryMap, Cache processWideCache, - BundleProcessorCache bundleProcessorCache) { + BundleProcessorCache bundleProcessorCache, + @Nullable DataSampler dataSampler) { this.options = options; this.fnApiRegistry = fnApiRegistry; this.beamFnDataClient = beamFnDataClient; @@ -218,6 +223,7 @@ public ProcessBundleHandler( new UnknownPTransformRunnerFactory(urnToPTransformRunnerFactoryMap.keySet()); this.processWideCache = processWideCache; this.bundleProcessorCache = bundleProcessorCache; + this.dataSampler = dataSampler; } private void createRunnerAndConsumersForPTransformRecursively( @@ -771,7 +777,11 @@ private BundleProcessor createBundleProcessor( bundleProgressReporterAndRegistrar.register(stateTracker); PCollectionConsumerRegistry pCollectionConsumerRegistry = new PCollectionConsumerRegistry( - stateTracker, shortIds, bundleProgressReporterAndRegistrar, bundleDescriptor); + stateTracker, + shortIds, + bundleProgressReporterAndRegistrar, + bundleDescriptor, + dataSampler); HashSet processedPTransformIds = new HashSet<>(); PTransformFunctionRegistry startFunctionRegistry = diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java index 60b25d8b1376..5095be1be8fb 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Map; import java.util.Random; +import javax.annotation.Nullable; import org.apache.beam.fn.harness.HandlesSplits; import org.apache.beam.fn.harness.control.BundleProgressReporter; import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionState; @@ -31,6 +32,8 @@ import org.apache.beam.fn.harness.control.Metrics; import org.apache.beam.fn.harness.control.Metrics.BundleCounter; import org.apache.beam.fn.harness.control.Metrics.BundleDistribution; +import org.apache.beam.fn.harness.debug.DataSampler; +import org.apache.beam.fn.harness.debug.OutputSampler; import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor; import org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo; import org.apache.beam.model.pipeline.v1.RunnerApi; @@ -86,12 +89,22 @@ public static ConsumerAndMetadata forConsumer( private final BundleProgressReporter.Registrar bundleProgressReporterRegistrar; private final ProcessBundleDescriptor processBundleDescriptor; private final RehydratedComponents rehydratedComponents; + private final @Nullable DataSampler dataSampler; public PCollectionConsumerRegistry( ExecutionStateTracker stateTracker, ShortIdMap shortIdMap, BundleProgressReporter.Registrar bundleProgressReporterRegistrar, ProcessBundleDescriptor processBundleDescriptor) { + this(stateTracker, shortIdMap, bundleProgressReporterRegistrar, processBundleDescriptor, null); + } + + public PCollectionConsumerRegistry( + ExecutionStateTracker stateTracker, + ShortIdMap shortIdMap, + BundleProgressReporter.Registrar bundleProgressReporterRegistrar, + ProcessBundleDescriptor processBundleDescriptor, + @Nullable DataSampler dataSampler) { this.stateTracker = stateTracker; this.shortIdMap = shortIdMap; this.pCollectionIdsToConsumers = new HashMap<>(); @@ -105,6 +118,7 @@ public PCollectionConsumerRegistry( .putAllPcollections(processBundleDescriptor.getPcollectionsMap()) .putAllWindowingStrategies(processBundleDescriptor.getWindowingStrategiesMap()) .build()); + this.dataSampler = dataSampler; } /** @@ -200,15 +214,17 @@ public FnDataReceiver> getMultiplexingConsumer(String pCollecti if (consumerAndMetadatas.size() == 1) { ConsumerAndMetadata consumerAndMetadata = consumerAndMetadatas.get(0); if (consumerAndMetadata.getConsumer() instanceof HandlesSplits) { - return new SplittingMetricTrackingFnDataReceiver(pcId, coder, consumerAndMetadata); + return new SplittingMetricTrackingFnDataReceiver( + pcId, coder, consumerAndMetadata, dataSampler); } - return new MetricTrackingFnDataReceiver(pcId, coder, consumerAndMetadata); + return new MetricTrackingFnDataReceiver(pcId, coder, consumerAndMetadata, dataSampler); } else { /* TODO(SDF), Consider supporting splitting each consumer individually. This would never come up in the existing SDF expansion, but might be useful to support fused SDF nodes. This would require dedicated delivery of the split results to each of the consumers separately. 
*/ - return new MultiplexingMetricTrackingFnDataReceiver(pcId, coder, consumerAndMetadatas); + return new MultiplexingMetricTrackingFnDataReceiver( + pcId, coder, consumerAndMetadatas, dataSampler); } }); } @@ -226,9 +242,13 @@ private class MetricTrackingFnDataReceiver implements FnDataReceiver sampledByteSizeDistribution; private final Coder coder; + private final @Nullable OutputSampler outputSampler; public MetricTrackingFnDataReceiver( - String pCollectionId, Coder coder, ConsumerAndMetadata consumerAndMetadata) { + String pCollectionId, + Coder coder, + ConsumerAndMetadata consumerAndMetadata, + @Nullable DataSampler dataSampler) { this.delegate = consumerAndMetadata.getConsumer(); this.executionState = consumerAndMetadata.getExecutionState(); @@ -264,6 +284,11 @@ public MetricTrackingFnDataReceiver( bundleProgressReporterRegistrar.register(sampledByteSizeUnderlyingDistribution); this.coder = coder; + if (dataSampler == null) { + this.outputSampler = null; + } else { + this.outputSampler = dataSampler.sampleOutput(pCollectionId, coder); + } } @Override @@ -274,6 +299,10 @@ public void accept(WindowedValue input) throws Exception { // we have window optimization. this.sampledByteSizeDistribution.tryUpdate(input.getValue(), this.coder); + if (outputSampler != null) { + outputSampler.sample(input.getValue()); + } + // Use the ExecutionStateTracker and enter an appropriate state to track the // Process Bundle Execution time metric and also ensure user counters can get an appropriate // metrics container. @@ -300,9 +329,13 @@ private class MultiplexingMetricTrackingFnDataReceiver private final BundleCounter elementCountCounter; private final SampleByteSizeDistribution sampledByteSizeDistribution; private final Coder coder; + private final @Nullable OutputSampler outputSampler; public MultiplexingMetricTrackingFnDataReceiver( - String pCollectionId, Coder coder, List consumerAndMetadatas) { + String pCollectionId, + Coder coder, + List consumerAndMetadatas, + @Nullable DataSampler dataSampler) { this.consumerAndMetadatas = consumerAndMetadatas; HashMap labels = new HashMap<>(); @@ -337,6 +370,11 @@ public MultiplexingMetricTrackingFnDataReceiver( bundleProgressReporterRegistrar.register(sampledByteSizeUnderlyingDistribution); this.coder = coder; + if (dataSampler == null) { + this.outputSampler = null; + } else { + this.outputSampler = dataSampler.sampleOutput(pCollectionId, coder); + } } @Override @@ -347,6 +385,10 @@ public void accept(WindowedValue input) throws Exception { // when we have window optimization. this.sampledByteSizeDistribution.tryUpdate(input.getValue(), coder); + if (outputSampler != null) { + outputSampler.sample(input.getValue()); + } + // Use the ExecutionStateTracker and enter an appropriate state to track the // Process Bundle Execution time metric and also ensure user counters can get an appropriate // metrics container. 
We specifically don't use a for-each loop since it creates an iterator @@ -377,8 +419,11 @@ private class SplittingMetricTrackingFnDataReceiver extends MetricTrackingFnD private final HandlesSplits delegate; public SplittingMetricTrackingFnDataReceiver( - String pCollection, Coder coder, ConsumerAndMetadata consumerAndMetadata) { - super(pCollection, coder, consumerAndMetadata); + String pCollection, + Coder coder, + ConsumerAndMetadata consumerAndMetadata, + @Nullable DataSampler dataSampler) { + super(pCollection, coder, consumerAndMetadata, dataSampler); this.delegate = (HandlesSplits) consumerAndMetadata.getConsumer(); } diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/debug/DataSampler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/debug/DataSampler.java new file mode 100644 index 000000000000..2a13b5dac3d3 --- /dev/null +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/debug/DataSampler.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.fn.harness.debug; + +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.beam.model.fnexecution.v1.BeamFnApi; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.SampleDataResponse.ElementList; +import org.apache.beam.sdk.coders.Coder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * The DataSampler is a global (per SDK Harness) object that facilitates taking and returning + * samples to the Runner Harness. The class is thread-safe with respect to executing + * ProcessBundleDescriptors. Meaning, different threads executing different PBDs can sample + * simultaneously, even if computing the same logical PCollection. + */ +public class DataSampler { + private static final Logger LOG = LoggerFactory.getLogger(DataSampler.class); + + /** + * Creates a DataSampler to sample every 1000 elements while keeping a maximum of 10 in memory. + */ + public DataSampler() { + this(10, 1000); + } + + /** + * @param maxSamples Sets the maximum number of samples held in memory at once. + * @param sampleEveryN Sets how often to sample. + */ + public DataSampler(int maxSamples, int sampleEveryN) { + checkArgument( + maxSamples > 0, + "Expected positive number of samples, did you mean to disable data sampling?"); + checkArgument( + sampleEveryN > 0, + "Expected positive number for sampling period, did you mean to disable data sampling?"); + this.maxSamples = maxSamples; + this.sampleEveryN = sampleEveryN; + } + + // Maximum number of elements in buffer. 
+ private final int maxSamples; + + // Sampling rate. + private final int sampleEveryN; + + // The fully-qualified type is: Map[PCollectionId, OutputSampler]. In order to sample + // on a PCollection-basis and not per-bundle, this keeps track of shared samples between states. + private final Map<String, OutputSampler<?>> outputSamplers = new ConcurrentHashMap<>(); + + /** + * Creates and returns a class to sample the given PCollection in the given + * ProcessBundleDescriptor. Uses the given coder to encode samples as bytes when responding to a + * SampleDataRequest. + * + *
<p>
Invoked by multiple bundle processing threads in parallel when a new bundle processor is + * being instantiated. + * + * @param pcollectionId The PCollection to take intermittent samples from. + * @param coder The coder associated with the PCollection. Coder may be from a nested context. + * @param The type of element contained in the PCollection. + * @return the OutputSampler corresponding to the unique PBD and PCollection. + */ + public OutputSampler sampleOutput(String pcollectionId, Coder coder) { + return (OutputSampler) + outputSamplers.computeIfAbsent( + pcollectionId, k -> new OutputSampler<>(coder, this.maxSamples, this.sampleEveryN)); + } + + /** + * Returns all collected samples. Thread-safe. + * + * @param request The instruction request from the FnApi. Filters based on the given + * SampleDataRequest. + * @return Returns all collected samples. + */ + public synchronized BeamFnApi.InstructionResponse.Builder handleDataSampleRequest( + BeamFnApi.InstructionRequest request) { + BeamFnApi.SampleDataRequest sampleDataRequest = request.getSampleData(); + + List pcollections = sampleDataRequest.getPcollectionIdsList(); + + // Safe to iterate as the ConcurrentHashMap will return each element at most once and will not + // throw ConcurrentModificationException. + BeamFnApi.SampleDataResponse.Builder response = BeamFnApi.SampleDataResponse.newBuilder(); + outputSamplers.forEach( + (pcollectionId, outputSampler) -> { + if (!pcollections.isEmpty() && !pcollections.contains(pcollectionId)) { + return; + } + + try { + response.putElementSamples( + pcollectionId, + ElementList.newBuilder().addAllElements(outputSampler.samples()).build()); + } catch (IOException e) { + LOG.warn("Could not encode elements from \"" + pcollectionId + "\" to bytes: " + e); + } + }); + + return BeamFnApi.InstructionResponse.newBuilder().setSampleData(response); + } +} diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/debug/OutputSampler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/debug/OutputSampler.java new file mode 100644 index 000000000000..326f2dbfe8f3 --- /dev/null +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/debug/OutputSampler.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.fn.harness.debug; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicLong; +import org.apache.beam.model.fnexecution.v1.BeamFnApi; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.util.ByteStringOutputStream; + +/** + * This class holds samples for a single PCollection until queried by the parent DataSampler. 
This + * class is meant to hold only a limited number of elements in memory, so old values are constantly + * being overwritten in a circular buffer. + * + * @param <T> the element type of the PCollection. + */ +public class OutputSampler<T> { + + // Temporarily holds elements until the SDK receives a sample data request. + private List<T> buffer; + + // Maximum number of elements in buffer. + private final int maxElements; + + // Sampling rate. + private final int sampleEveryN; + + // Total number of samples taken. + private final AtomicLong numSamples = new AtomicLong(); + + // Index into the buffer of where to overwrite samples. + private int resampleIndex = 0; + + private final Coder<T> coder; + + public OutputSampler(Coder<T> coder, int maxElements, int sampleEveryN) { + this.coder = coder; + this.maxElements = maxElements; + this.sampleEveryN = sampleEveryN; + this.buffer = new ArrayList<>(this.maxElements); + } + + /** + * Samples every {@code sampleEveryN}th element or if it is part of the first 10 in the (local) + * PCollection. + * + *
<p>
This method is invoked in parallel by multiple bundle processing threads and in parallel to + * any {@link #samples} being returned to a thread handling a sample request. + * + * @param element the element to sample. + */ + public void sample(T element) { + // Always sample the first 10 elements, then every `sampleEveryN`th element after that. + long samples = numSamples.get() + 1; + + // This is eventually consistent. If many threads update concurrently, the counter may be set to + // the value written by the slowest thread, but over time it still increases. That is acceptable + // because this is a debugging feature and doesn't need strict atomicity. + numSamples.lazySet(samples); + if (samples > 10 && samples % sampleEveryN != 0) { + return; + } + + synchronized (this) { + // Fill buffer until maxElements. + if (buffer.size() < maxElements) { + buffer.add(element); + } else { + // Then rewrite sampled elements as a circular buffer. + buffer.set(resampleIndex, element); + resampleIndex = (resampleIndex + 1) % maxElements; + } + } + } + + /** + * Clears the buffered samples at the end of the call to help mitigate memory use. + * + *
<p>
This method is invoked by a thread handling a data sampling request in parallel to any calls + * to {@link #sample}. + * + * @return samples taken since last call. + */ + public List samples() throws IOException { + List ret = new ArrayList<>(); + + // Serializing can take a lot of CPU time for larger or complex elements. Copy the array here + // so as to not slow down the main processing hot path. + List bufferToSend; + int sampleIndex = 0; + synchronized (this) { + bufferToSend = buffer; + sampleIndex = resampleIndex; + buffer = new ArrayList<>(maxElements); + resampleIndex = 0; + } + + ByteStringOutputStream stream = new ByteStringOutputStream(); + for (int i = 0; i < bufferToSend.size(); i++) { + int index = (sampleIndex + i) % bufferToSend.size(); + // This is deprecated, but until this is fully removed, this specifically needs the nested + // context. This is because the SDK will need to decode the sampled elements with the + // ToStringFn. + coder.encode(bufferToSend.get(index), stream, Coder.Context.NESTED); + ret.add( + BeamFnApi.SampledElement.newBuilder().setElement(stream.toByteStringAndReset()).build()); + } + + return ret; + } +} diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/debug/package-info.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/debug/package-info.java new file mode 100644 index 000000000000..978bcd346d47 --- /dev/null +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/debug/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Classes and utilities related to debugging features. 
*/ +package org.apache.beam.fn.harness.debug; diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java index 7df9ed2f894d..52bb72f97894 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java @@ -376,7 +376,8 @@ public void testTrySplitBeforeBundleDoesNotFail() { executionStateSampler, ImmutableMap.of(), Caches.noop(), - new BundleProcessorCache()); + new BundleProcessorCache(), + null /* dataSampler */); BeamFnApi.InstructionResponse response = handler @@ -406,7 +407,8 @@ public void testProgressBeforeBundleDoesNotFail() throws Exception { executionStateSampler, ImmutableMap.of(), Caches.noop(), - new BundleProcessorCache()); + new BundleProcessorCache(), + null /* dataSampler */); handler.progress( BeamFnApi.InstructionRequest.newBuilder() @@ -485,7 +487,8 @@ public void testOrderOfStartAndFinishCalls() throws Exception { DATA_INPUT_URN, startFinishRecorder, DATA_OUTPUT_URN, startFinishRecorder), Caches.noop(), - new BundleProcessorCache()); + new BundleProcessorCache(), + null /* dataSampler */); handler.processBundle( BeamFnApi.InstructionRequest.newBuilder() @@ -589,7 +592,8 @@ public void testOrderOfSetupTeardownCalls() throws Exception { executionStateSampler, urnToPTransformRunnerFactoryMap, Caches.noop(), - new BundleProcessorCache()); + new BundleProcessorCache(), + null /* dataSampler */); handler.processBundle( BeamFnApi.InstructionRequest.newBuilder() @@ -640,7 +644,8 @@ public void testBundleProcessorIsResetWhenAddedBackToCache() throws Exception { executionStateSampler, ImmutableMap.of(DATA_INPUT_URN, (context) -> null), Caches.noop(), - new TestBundleProcessorCache()); + new TestBundleProcessorCache(), + null /* dataSampler */); assertThat(TestBundleProcessor.resetCnt, equalTo(0)); @@ -806,7 +811,8 @@ public void testCreatingPTransformExceptionsArePropagated() throws Exception { throw new IllegalStateException("TestException"); }), Caches.noop(), - new BundleProcessorCache()); + new BundleProcessorCache(), + null /* dataSampler */); assertThrows( "TestException", IllegalStateException.class, @@ -856,7 +862,8 @@ public void testBundleFinalizationIsPropagated() throws Exception { return null; }), Caches.noop(), - new BundleProcessorCache()); + new BundleProcessorCache(), + null /* dataSampler */); BeamFnApi.InstructionResponse.Builder response = handler.processBundle( BeamFnApi.InstructionRequest.newBuilder() @@ -909,7 +916,8 @@ public void testPTransformStartExceptionsArePropagated() { return null; }), Caches.noop(), - new BundleProcessorCache()); + new BundleProcessorCache(), + null /* dataSampler */); assertThrows( "TestException", IllegalStateException.class, @@ -1086,7 +1094,8 @@ public void onCompleted() {} executionStateSampler, urnToPTransformRunnerFactoryMap, Caches.noop(), - new BundleProcessorCache()); + new BundleProcessorCache(), + null /* dataSampler */); } @Test @@ -1418,7 +1427,8 @@ public void testInstructionIsUnregisteredFromBeamFnDataClientOnSuccess() throws return null; }), Caches.noop(), - new BundleProcessorCache()); + new BundleProcessorCache(), + null /* dataSampler */); handler.processBundle( BeamFnApi.InstructionRequest.newBuilder() .setInstructionId("instructionId") @@ -1490,7 +1500,8 @@ public void testDataProcessingExceptionsArePropagated() 
throws Exception { return null; }), Caches.noop(), - new BundleProcessorCache()); + new BundleProcessorCache(), + null /* dataSampler */); assertThrows( "TestException", IllegalStateException.class, @@ -1539,7 +1550,8 @@ public void testPTransformFinishExceptionsArePropagated() throws Exception { return null; }), Caches.noop(), - new BundleProcessorCache()); + new BundleProcessorCache(), + null /* dataSampler */); assertThrows( "TestException", IllegalStateException.class, @@ -1634,7 +1646,8 @@ private void doStateCalls(BeamFnStateClient beamFnStateClient) { } }), Caches.noop(), - new BundleProcessorCache()); + new BundleProcessorCache(), + null /* dataSampler */); handler.processBundle( BeamFnApi.InstructionRequest.newBuilder() .setProcessBundle( @@ -1684,7 +1697,8 @@ private void doStateCalls(BeamFnStateClient beamFnStateClient) { } }), Caches.noop(), - new BundleProcessorCache()); + new BundleProcessorCache(), + null /* dataSampler */); assertThrows( "State API calls are unsupported", IllegalStateException.class, @@ -1786,7 +1800,8 @@ public void reset() { executionStateSampler, ImmutableMap.of(DATA_INPUT_URN, startFinishGuard), Caches.noop(), - bundleProcessorCache); + bundleProcessorCache, + null /* dataSampler */); AtomicBoolean progressShouldExit = new AtomicBoolean(); Future bundleProcessorTask = @@ -1914,7 +1929,8 @@ public Object createRunnerForPTransform(Context context) throws IOException { } }), Caches.noop(), - new BundleProcessorCache()); + new BundleProcessorCache(), + null /* dataSampler */); assertThrows( "Timers are unsupported", IllegalStateException.class, diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java index 35bd5697adc0..c24f016b5cc1 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java @@ -21,6 +21,8 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doAnswer; @@ -39,6 +41,8 @@ import org.apache.beam.fn.harness.control.BundleProgressReporter; import org.apache.beam.fn.harness.control.ExecutionStateSampler; import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTracker; +import org.apache.beam.fn.harness.debug.DataSampler; +import org.apache.beam.model.fnexecution.v1.BeamFnApi; import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor; import org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo; import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection; @@ -56,6 +60,7 @@ import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.metrics.MetricsEnvironment; import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.util.ByteStringOutputStream; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterable; import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterator; @@ -507,6 +512,61 @@ public void testLazyByteSizeEstimation() throws Exception { assertThat(result, 
containsInAnyOrder(expected.toArray())); } + /** + * Test that element samples are taken when a DataSampler is present. + * + * @throws Exception + */ + @Test + public void dataSampling() throws Exception { + final String pTransformIdA = "pTransformIdA"; + + ShortIdMap shortIds = new ShortIdMap(); + BundleProgressReporter.InMemory reporterAndRegistrar = new BundleProgressReporter.InMemory(); + DataSampler dataSampler = new DataSampler(); + PCollectionConsumerRegistry consumers = + new PCollectionConsumerRegistry( + sampler.create(), shortIds, reporterAndRegistrar, TEST_DESCRIPTOR, dataSampler); + FnDataReceiver> consumerA1 = mock(FnDataReceiver.class); + + consumers.register(P_COLLECTION_A, pTransformIdA, pTransformIdA + "Name", consumerA1); + + FnDataReceiver> wrapperConsumer = + (FnDataReceiver>) + (FnDataReceiver) consumers.getMultiplexingConsumer(P_COLLECTION_A); + String elementValue = "elem"; + WindowedValue element = valueInGlobalWindow(elementValue); + int numElements = 10; + for (int i = 0; i < numElements; i++) { + wrapperConsumer.accept(element); + } + + BeamFnApi.InstructionRequest request = + BeamFnApi.InstructionRequest.newBuilder() + .setSampleData(BeamFnApi.SampleDataRequest.newBuilder()) + .build(); + BeamFnApi.InstructionResponse response = dataSampler.handleDataSampleRequest(request).build(); + + Map elementSamplesMap = + response.getSampleData().getElementSamplesMap(); + + assertFalse(elementSamplesMap.isEmpty()); + + BeamFnApi.SampleDataResponse.ElementList elementList = elementSamplesMap.get(P_COLLECTION_A); + assertNotNull(elementList); + + List expectedSamples = new ArrayList<>(); + StringUtf8Coder coder = StringUtf8Coder.of(); + for (int i = 0; i < numElements; i++) { + ByteStringOutputStream stream = new ByteStringOutputStream(); + coder.encode(elementValue, stream); + expectedSamples.add( + BeamFnApi.SampledElement.newBuilder().setElement(stream.toByteStringAndReset()).build()); + } + + assertTrue(elementList.getElementsList().containsAll(expectedSamples)); + } + private static class TestElementByteSizeObservableIterable extends ElementByteSizeObservableIterable> { private List elements; diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/debug/DataSamplerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/debug/DataSamplerTest.java new file mode 100644 index 000000000000..4b874dd7e980 --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/debug/DataSamplerTest.java @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.fn.harness.debug; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import org.apache.beam.model.fnexecution.v1.BeamFnApi; +import org.apache.beam.sdk.coders.ByteArrayCoder; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.vendor.grpc.v1p48p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class DataSamplerTest { + byte[] encodeInt(Integer i) throws IOException { + VarIntCoder coder = VarIntCoder.of(); + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + coder.encode(i, stream, Coder.Context.NESTED); + return stream.toByteArray(); + } + + byte[] encodeString(String s) throws IOException { + StringUtf8Coder coder = StringUtf8Coder.of(); + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + coder.encode(s, stream, Coder.Context.NESTED); + return stream.toByteArray(); + } + + byte[] encodeByteArray(byte[] b) throws IOException { + ByteArrayCoder coder = ByteArrayCoder.of(); + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + coder.encode(b, stream, Coder.Context.NESTED); + return stream.toByteArray(); + } + + BeamFnApi.InstructionResponse getAllSamples(DataSampler dataSampler) { + BeamFnApi.InstructionRequest request = + BeamFnApi.InstructionRequest.newBuilder() + .setSampleData(BeamFnApi.SampleDataRequest.newBuilder().build()) + .build(); + return dataSampler.handleDataSampleRequest(request).build(); + } + + BeamFnApi.InstructionResponse getSamplesForPCollection( + DataSampler dataSampler, String pcollection) { + BeamFnApi.InstructionRequest request = + BeamFnApi.InstructionRequest.newBuilder() + .setSampleData( + BeamFnApi.SampleDataRequest.newBuilder().addPcollectionIds(pcollection).build()) + .build(); + return dataSampler.handleDataSampleRequest(request).build(); + } + + BeamFnApi.InstructionResponse getSamplesForPCollections( + DataSampler dataSampler, Iterable pcollections) { + BeamFnApi.InstructionRequest request = + BeamFnApi.InstructionRequest.newBuilder() + .setSampleData( + BeamFnApi.SampleDataRequest.newBuilder().addAllPcollectionIds(pcollections).build()) + .build(); + return dataSampler.handleDataSampleRequest(request).build(); + } + + void assertHasSamples( + BeamFnApi.InstructionResponse response, String pcollection, Iterable elements) { + Map elementSamplesMap = + response.getSampleData().getElementSamplesMap(); + + assertFalse(elementSamplesMap.isEmpty()); + + BeamFnApi.SampleDataResponse.ElementList elementList = elementSamplesMap.get(pcollection); + assertNotNull(elementList); + + List expectedSamples = new ArrayList<>(); + for (byte[] el : elements) { + expectedSamples.add( + BeamFnApi.SampledElement.newBuilder().setElement(ByteString.copyFrom(el)).build()); + } + + assertTrue(elementList.getElementsList().containsAll(expectedSamples)); + } + + /** + * 
Smoke test that a sample shows in the output map. + * + * @throws Exception + */ + @Test + public void testSingleOutput() throws Exception { + DataSampler sampler = new DataSampler(); + + VarIntCoder coder = VarIntCoder.of(); + sampler.sampleOutput("pcollection-id", coder).sample(1); + + BeamFnApi.InstructionResponse samples = getAllSamples(sampler); + assertHasSamples(samples, "pcollection-id", Collections.singleton(encodeInt(1))); + } + + /** + * Smoke test that a sample is encoded in the nested context and shows in the output map. + * + * @throws Exception + */ + @Test + public void testNestedContext() throws Exception { + DataSampler sampler = new DataSampler(); + + String rawString = "hello"; + byte[] byteArray = rawString.getBytes(StandardCharsets.US_ASCII); + ByteArrayCoder coder = ByteArrayCoder.of(); + sampler.sampleOutput("pcollection-id", coder).sample(byteArray); + + BeamFnApi.InstructionResponse samples = getAllSamples(sampler); + assertHasSamples(samples, "pcollection-id", Collections.singleton(encodeByteArray(byteArray))); + } + + /** + * Test that sampling multiple PCollections under the same descriptor is OK. + * + * @throws Exception + */ + @Test + public void testMultipleOutputs() throws Exception { + DataSampler sampler = new DataSampler(); + + VarIntCoder coder = VarIntCoder.of(); + sampler.sampleOutput("pcollection-id-1", coder).sample(1); + sampler.sampleOutput("pcollection-id-2", coder).sample(2); + + BeamFnApi.InstructionResponse samples = getAllSamples(sampler); + assertHasSamples(samples, "pcollection-id-1", Collections.singleton(encodeInt(1))); + assertHasSamples(samples, "pcollection-id-2", Collections.singleton(encodeInt(2))); + } + + /** + * Test that the response contains samples from the same PCollection across descriptors. + * + * @throws Exception + */ + @Test + public void testMultipleSamePCollections() throws Exception { + DataSampler sampler = new DataSampler(); + + VarIntCoder coder = VarIntCoder.of(); + sampler.sampleOutput("pcollection-id", coder).sample(1); + sampler.sampleOutput("pcollection-id", coder).sample(2); + + BeamFnApi.InstructionResponse samples = getAllSamples(sampler); + assertHasSamples(samples, "pcollection-id", ImmutableList.of(encodeInt(1), encodeInt(2))); + } + + void generateStringSamples(DataSampler sampler) { + StringUtf8Coder coder = StringUtf8Coder.of(); + sampler.sampleOutput("a", coder).sample("a1"); + sampler.sampleOutput("a", coder).sample("a2"); + sampler.sampleOutput("b", coder).sample("b1"); + sampler.sampleOutput("b", coder).sample("b2"); + sampler.sampleOutput("c", coder).sample("c1"); + sampler.sampleOutput("c", coder).sample("c2"); + } + + /** + * Test that samples can be filtered based on PCollection id. + * + * @throws Exception + */ + @Test + public void testFiltersSinglePCollectionId() throws Exception { + DataSampler sampler = new DataSampler(10, 10); + generateStringSamples(sampler); + + BeamFnApi.InstructionResponse samples = getSamplesForPCollection(sampler, "a"); + assertHasSamples(samples, "a", ImmutableList.of(encodeString("a1"), encodeString("a2"))); + } + + /** + * Test that samples can be filtered on multiple PCollection ids.
+ * + * @throws Exception + */ + @Test + public void testFiltersMultiplePCollectionIds() throws Exception { + List pcollectionIds = ImmutableList.of("a", "c"); + + DataSampler sampler = new DataSampler(10, 10); + generateStringSamples(sampler); + + BeamFnApi.InstructionResponse samples = getSamplesForPCollections(sampler, pcollectionIds); + assertThat(samples.getSampleData().getElementSamplesMap().size(), equalTo(2)); + assertHasSamples(samples, "a", ImmutableList.of(encodeString("a1"), encodeString("a2"))); + assertHasSamples(samples, "c", ImmutableList.of(encodeString("c1"), encodeString("c2"))); + } + + /** + * Test that samples can be taken from the DataSampler while adding new OutputSamplers. This fails + * with a ConcurrentModificationException if there is a bug. + * + * @throws Exception + */ + @Test + public void testConcurrentNewSampler() throws Exception { + DataSampler sampler = new DataSampler(); + VarIntCoder coder = VarIntCoder.of(); + + // Make threads that will create 100 individual OutputSamplers each. + Thread[] sampleThreads = new Thread[100]; + CountDownLatch startSignal = new CountDownLatch(1); + CountDownLatch doneSignal = new CountDownLatch(sampleThreads.length); + + for (int i = 0; i < sampleThreads.length; i++) { + sampleThreads[i] = + new Thread( + () -> { + try { + startSignal.await(); + } catch (InterruptedException e) { + return; + } + + for (int j = 0; j < 100; j++) { + sampler.sampleOutput("pcollection-" + j, coder).sample(0); + } + + doneSignal.countDown(); + }); + sampleThreads[i].start(); + } + + startSignal.countDown(); + while (doneSignal.getCount() > 0) { + sampler.handleDataSampleRequest( + BeamFnApi.InstructionRequest.newBuilder() + .setSampleData(BeamFnApi.SampleDataRequest.newBuilder()) + .build()); + } + + for (Thread sampleThread : sampleThreads) { + sampleThread.join(); + } + } +} diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/debug/OutputSamplerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/debug/OutputSamplerTest.java new file mode 100644 index 000000000000..953ccce9e235 --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/debug/OutputSamplerTest.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.fn.harness.debug; + +import static junit.framework.TestCase.assertEquals; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import org.apache.beam.model.fnexecution.v1.BeamFnApi; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.vendor.grpc.v1p48p1.com.google.protobuf.ByteString; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class OutputSamplerTest { + public BeamFnApi.SampledElement encodeInt(Integer i) throws IOException { + VarIntCoder coder = VarIntCoder.of(); + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + coder.encode(i, stream); + return BeamFnApi.SampledElement.newBuilder() + .setElement(ByteString.copyFrom(stream.toByteArray())) + .build(); + } + + /** + * Test that the first N are always sampled. + * + * @throws Exception when encoding fails (shouldn't happen). + */ + @Test + public void testSamplesFirstN() throws Exception { + VarIntCoder coder = VarIntCoder.of(); + OutputSampler outputSampler = new OutputSampler<>(coder, 10, 10); + + // Purposely go over maxSamples and sampleEveryN. This helps to increase confidence. + for (int i = 0; i < 15; ++i) { + outputSampler.sample(i); + } + + // The expected list is only 0..9 inclusive. + List expected = new ArrayList<>(); + for (int i = 0; i < 10; ++i) { + expected.add(encodeInt(i)); + } + + List samples = outputSampler.samples(); + assertThat(samples, containsInAnyOrder(expected.toArray())); + } + + /** + * Test that the previous values are overwritten and only the most recent `maxSamples` are kept. + * + * @throws Exception when encoding fails (shouldn't happen). + */ + @Test + public void testActsLikeCircularBuffer() throws Exception { + VarIntCoder coder = VarIntCoder.of(); + OutputSampler outputSampler = new OutputSampler<>(coder, 5, 20); + + for (int i = 0; i < 100; ++i) { + outputSampler.sample(i); + } + + // The first 10 are always sampled, but with maxSamples = 5, only elements 5..9 inclusive remain + // in the buffer. Then the 20th element (19) is sampled, and every 20th element after that + // overwrites the previous samples. + List expected = new ArrayList<>(); + expected.add(encodeInt(19)); + expected.add(encodeInt(39)); + expected.add(encodeInt(59)); + expected.add(encodeInt(79)); + expected.add(encodeInt(99)); + + List samples = outputSampler.samples(); + assertThat(samples, containsInAnyOrder(expected.toArray())); + } + + /** + * Test that sampling a PCollection while retrieving samples from multiple threads is ok. + * + * @throws Exception + */ + @Test + public void testConcurrentSamples() throws Exception { + VarIntCoder coder = VarIntCoder.of(); + OutputSampler outputSampler = new OutputSampler<>(coder, 10, 2); + + CountDownLatch startSignal = new CountDownLatch(1); + CountDownLatch doneSignal = new CountDownLatch(2); + + // Iteration count was empirically chosen to have a high probability of failure without the + // test going for too long. + // Generates a range of numbers from 0 to 1000000. 
+ Thread sampleThreadA = + new Thread( + () -> { + try { + startSignal.await(); + } catch (InterruptedException e) { + return; + } + + for (int i = 0; i < 1000000; i++) { + outputSampler.sample(i); + } + + doneSignal.countDown(); + }); + + // Generates a range of numbers from -1000000 to 0. + Thread sampleThreadB = + new Thread( + () -> { + try { + startSignal.await(); + } catch (InterruptedException e) { + return; + } + + for (int i = -1000000; i < 0; i++) { + outputSampler.sample(i); + } + + doneSignal.countDown(); + }); + + // Ready the threads. + sampleThreadA.start(); + sampleThreadB.start(); + + // Start the threads at the same time. + startSignal.countDown(); + + // Generate contention by sampling at the same time as the samples are generated. + List samples = new ArrayList<>(); + while (doneSignal.getCount() > 0) { + samples.addAll(outputSampler.samples()); + } + + // Stop the threads and sort the samples from which thread it came from. + sampleThreadA.join(); + sampleThreadB.join(); + List samplesFromThreadA = new ArrayList<>(); + List samplesFromThreadB = new ArrayList<>(); + for (BeamFnApi.SampledElement sampledElement : samples) { + int el = coder.decode(sampledElement.getElement().newInput()); + if (el >= 0) { + samplesFromThreadA.add(el); + } else { + samplesFromThreadB.add(el); + } + } + + // Copy the array and sort it. + List sortedSamplesFromThreadA = new ArrayList<>(samplesFromThreadA); + List sortedSamplesFromThreadB = new ArrayList<>(samplesFromThreadB); + Collections.sort(sortedSamplesFromThreadA); + Collections.sort(sortedSamplesFromThreadB); + + // Order is preserved when getting the samples. If there is a weird race condition, these + // numbers may be out of order. + assertEquals(samplesFromThreadA, sortedSamplesFromThreadA); + assertEquals(samplesFromThreadB, sortedSamplesFromThreadB); + } +}
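Reviewer note (not part of the patch): the sketch below is a minimal, self-contained illustration of how the pieces introduced in this diff fit together — one `DataSampler` per SDK harness, one `OutputSampler` per sampled PCollection, and the new `SAMPLE_DATA` instruction used to pull the encoded samples. The class name, PCollection id, and element values are invented for illustration; only the `DataSampler`/`OutputSampler`/`SampleDataRequest` calls come from this change.

```java
// Illustrative sketch only; "DataSamplerUsageSketch" and "example-pcollection" are made up.
import org.apache.beam.fn.harness.debug.DataSampler;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.sdk.coders.StringUtf8Coder;

public class DataSamplerUsageSketch {
  public static void main(String[] args) {
    // Same limits as the no-arg constructor: keep at most 10 samples per PCollection and,
    // after the first 10 elements, sample every 1000th element.
    DataSampler dataSampler = new DataSampler(10, 1000);

    // Each MetricTrackingFnDataReceiver obtains the OutputSampler for its PCollection id once
    // (computeIfAbsent returns the shared instance) and feeds it decoded elements on the hot path.
    StringUtf8Coder coder = StringUtf8Coder.of();
    dataSampler.sampleOutput("example-pcollection", coder).sample("hello");
    dataSampler.sampleOutput("example-pcollection", coder).sample("world");

    // A runner-issued SampleDataRequest, routed through InstructionRequest.SAMPLE_DATA in
    // FnHarness, returns the coder-encoded samples keyed by PCollection id.
    BeamFnApi.InstructionRequest request =
        BeamFnApi.InstructionRequest.newBuilder()
            .setSampleData(
                BeamFnApi.SampleDataRequest.newBuilder().addPcollectionIds("example-pcollection"))
            .build();
    BeamFnApi.InstructionResponse response = dataSampler.handleDataSampleRequest(request).build();
    System.out.println(response.getSampleData().getElementSamplesMap());
  }
}
```

In a running pipeline the same path is gated by the `enable_data_sampling` experiment (e.g. `--experiments=enable_data_sampling`); without it, FnHarness still registers the `SAMPLE_DATA` handler but hands `null` to `ProcessBundleHandler`, so no `OutputSampler` is ever created and sample responses stay empty.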