diff --git a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/ProcessingExceptionHandlerIntegrationTest.java b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/ProcessingExceptionHandlerIntegrationTest.java index 13e291e887cc3..38711093ff8af 100644 --- a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/ProcessingExceptionHandlerIntegrationTest.java +++ b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/ProcessingExceptionHandlerIntegrationTest.java @@ -16,14 +16,17 @@ */ package org.apache.kafka.streams.integration; +import org.apache.kafka.clients.producer.ProducerRecord; import org.apache.kafka.common.MetricName; import org.apache.kafka.common.serialization.Serdes; import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; import org.apache.kafka.streams.KeyValue; import org.apache.kafka.streams.KeyValueTimestamp; import org.apache.kafka.streams.StreamsBuilder; import org.apache.kafka.streams.StreamsConfig; import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.Topology; import org.apache.kafka.streams.TopologyTestDriver; import org.apache.kafka.streams.errors.ErrorHandlerContext; import org.apache.kafka.streams.errors.LogAndContinueProcessingExceptionHandler; @@ -31,14 +34,22 @@ import org.apache.kafka.streams.errors.ProcessingExceptionHandler; import org.apache.kafka.streams.errors.StreamsException; import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.JoinWindows; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.StreamJoined; import org.apache.kafka.streams.processor.api.ContextualProcessor; import org.apache.kafka.streams.processor.api.ProcessorSupplier; import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.state.KeyValueStore; import org.apache.kafka.test.MockProcessorSupplier; import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import java.time.Duration; import java.time.Instant; @@ -48,6 +59,7 @@ import java.util.Map; import java.util.Properties; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Stream; import static org.apache.kafka.common.utils.Utils.mkEntry; import static org.apache.kafka.common.utils.Utils.mkMap; @@ -385,6 +397,131 @@ public void shouldStopProcessingWhenFatalUserExceptionProcessingExceptionHandler } } + static Stream<Arguments> sourceRawRecordTopologyTestCases() { + // Validate source raw key and source raw value for fully stateless topology + final List<ProducerRecord<String, String>> statelessTopologyEvent = List.of(new ProducerRecord<>("TOPIC_NAME", "ID123-1", "ID123-A1")); + final StreamsBuilder statelessTopologyBuilder = new StreamsBuilder(); + statelessTopologyBuilder + .stream("TOPIC_NAME", Consumed.with(Serdes.String(), Serdes.String())) + .selectKey((key, value) -> "newKey") + .mapValues(value -> { + throw new RuntimeException("Error"); + }); + + // Validate source raw key and source raw value for processing exception in aggregator with caching enabled + final List<ProducerRecord<String, String>> cacheAggregateExceptionInAggregatorEvent = List.of(new ProducerRecord<>("TOPIC_NAME", "INITIAL-KEY123-1", "ID123-A1")); + final StreamsBuilder
cacheAggregateExceptionInAggregatorTopologyBuilder = new StreamsBuilder(); + cacheAggregateExceptionInAggregatorTopologyBuilder + .stream("TOPIC_NAME", Consumed.with(Serdes.String(), Serdes.String())) + .groupBy((key, value) -> "ID123-1", Grouped.with(Serdes.String(), Serdes.String())) + .aggregate(() -> "initialValue", + (key, value, aggregate) -> { + throw new RuntimeException("Error"); + }, + Materialized.<String, String, KeyValueStore<Bytes, byte[]>>as("aggregate") + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.String()) + .withCachingEnabled()); + + // Validate source raw key and source raw value for processing exception after aggregation with caching enabled + final List<ProducerRecord<String, String>> cacheAggregateExceptionAfterAggregationEvent = List.of(new ProducerRecord<>("TOPIC_NAME", "INITIAL-KEY123-1", "ID123-A1")); + final StreamsBuilder cacheAggregateExceptionAfterAggregationTopologyBuilder = new StreamsBuilder(); + cacheAggregateExceptionAfterAggregationTopologyBuilder + .stream("TOPIC_NAME", Consumed.with(Serdes.String(), Serdes.String())) + .groupBy((key, value) -> "ID123-1", Grouped.with(Serdes.String(), Serdes.String())) + .aggregate(() -> "initialValue", + (key, value, aggregate) -> value, + Materialized.<String, String, KeyValueStore<Bytes, byte[]>>as("aggregate") + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.String()) + .withCachingEnabled()) + .mapValues(value -> { + throw new RuntimeException("Error"); + }); + + // Validate source raw key and source raw value for processing exception after aggregation with caching disabled + final List<ProducerRecord<String, String>> noCacheAggregateExceptionAfterAggregationEvents = List.of(new ProducerRecord<>("TOPIC_NAME", "INITIAL-KEY123-1", "ID123-A1")); + final StreamsBuilder noCacheAggregateExceptionAfterAggregationTopologyBuilder = new StreamsBuilder(); + noCacheAggregateExceptionAfterAggregationTopologyBuilder + .stream("TOPIC_NAME", Consumed.with(Serdes.String(), Serdes.String())) + .groupBy((key, value) -> "ID123-1", Grouped.with(Serdes.String(), Serdes.String())) + .aggregate(() -> "initialValue", + (key, value, aggregate) -> value, + Materialized.<String, String, KeyValueStore<Bytes, byte[]>>as("aggregate") + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.String()) + .withCachingDisabled()) + .mapValues(value -> { + throw new RuntimeException("Error"); + }); + + // Validate source raw key and source raw value for processing exception after table creation with caching enabled + final List<ProducerRecord<String, String>> cacheTableEvents = List.of(new ProducerRecord<>("TOPIC_NAME", "ID123-1", "ID123-A1")); + final StreamsBuilder cacheTableTopologyBuilder = new StreamsBuilder(); + cacheTableTopologyBuilder + .table("TOPIC_NAME", Consumed.with(Serdes.String(), Serdes.String()), + Materialized.<String, String, KeyValueStore<Bytes, byte[]>>as("table") + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.String()) + .withCachingEnabled()) + .mapValues(value -> { + throw new RuntimeException("Error"); + }); + + // Validate source raw key and source raw value for processing exception in join + final List<ProducerRecord<String, String>> joinEvents = List.of( + new ProducerRecord<>("TOPIC_NAME_2", "INITIAL-KEY123-1", "ID123-A1"), + new ProducerRecord<>("TOPIC_NAME", "INITIAL-KEY123-2", "ID123-A1") + ); + final StreamsBuilder joinTopologyBuilder = new StreamsBuilder(); + joinTopologyBuilder + .stream("TOPIC_NAME", Consumed.with(Serdes.String(), Serdes.String())) + .selectKey((key, value) -> "ID123-1") + .leftJoin(joinTopologyBuilder.stream("TOPIC_NAME_2", Consumed.with(Serdes.String(), Serdes.String())) + .selectKey((key, value) -> "ID123-1"), + (key, left, right) -> { + throw new RuntimeException("Error"); + }, + JoinWindows.ofTimeDifferenceAndGrace(Duration.ofMinutes(5),
Duration.ofMinutes(1)), + StreamJoined.with( + Serdes.String(), Serdes.String(), Serdes.String()) + .withName("join-rekey") + .withStoreName("join-store")); + + return Stream.of( + Arguments.of(statelessTopologyEvent, statelessTopologyBuilder.build()), + Arguments.of(cacheAggregateExceptionInAggregatorEvent, cacheAggregateExceptionInAggregatorTopologyBuilder.build()), + Arguments.of(cacheAggregateExceptionAfterAggregationEvent, cacheAggregateExceptionAfterAggregationTopologyBuilder.build()), + Arguments.of(noCacheAggregateExceptionAfterAggregationEvents, noCacheAggregateExceptionAfterAggregationTopologyBuilder.build()), + Arguments.of(cacheTableEvents, cacheTableTopologyBuilder.build()), + Arguments.of(joinEvents, joinTopologyBuilder.build()) + ); + } + + @ParameterizedTest + @MethodSource("sourceRawRecordTopologyTestCases") + public void shouldVerifySourceRawKeyAndSourceRawValuePresentOrNotInErrorHandlerContext(final List<ProducerRecord<String, String>> events, + final Topology topology) { + final Properties properties = new Properties(); + properties.put(StreamsConfig.PROCESSING_EXCEPTION_HANDLER_CLASS_CONFIG, + AssertSourceRawRecordProcessingExceptionHandlerMockTest.class); + + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, properties, Instant.ofEpochMilli(0L))) { + for (final ProducerRecord<String, String> event : events) { + final TestInputTopic<String, String> inputTopic = driver.createInputTopic(event.topic(), new StringSerializer(), new StringSerializer()); + + final String key = event.key(); + final String value = event.value(); + + if (event.topic().equals("TOPIC_NAME")) { + assertThrows(StreamsException.class, () -> inputTopic.pipeInput(key, value, TIMESTAMP)); + } else { + inputTopic.pipeInput(event.key(), event.value(), TIMESTAMP); + } + } + } + } + public static class ContinueProcessingExceptionHandlerMockTest implements ProcessingExceptionHandler { @Override public ProcessingExceptionHandler.ProcessingHandlerResponse handle(final ErrorHandlerContext context, final Record<?, ?> record, final Exception exception) { @@ -422,10 +559,28 @@ private static void assertProcessingExceptionHandlerInputs(final ErrorHandlerCon assertTrue(Arrays.asList("ID123-A2", "ID123-A5").contains((String) record.value())); assertEquals("TOPIC_NAME", context.topic()); assertEquals("KSTREAM-PROCESSOR-0000000003", context.processorNodeId()); + assertTrue(Arrays.equals("ID123-2-ERR".getBytes(), context.sourceRawKey()) + || Arrays.equals("ID123-5-ERR".getBytes(), context.sourceRawKey())); + assertTrue(Arrays.equals("ID123-A2".getBytes(), context.sourceRawValue()) + || Arrays.equals("ID123-A5".getBytes(), context.sourceRawValue())); assertEquals(TIMESTAMP.toEpochMilli(), context.timestamp()); assertTrue(exception.getMessage().contains("Exception should be handled by processing exception handler")); } + public static class AssertSourceRawRecordProcessingExceptionHandlerMockTest implements ProcessingExceptionHandler { + @Override + public ProcessingExceptionHandler.ProcessingHandlerResponse handle(final ErrorHandlerContext context, final Record<?, ?> record, final Exception exception) { + assertEquals("ID123-1", Serdes.String().deserializer().deserialize("topic", context.sourceRawKey())); + assertEquals("ID123-A1", Serdes.String().deserializer().deserialize("topic", context.sourceRawValue())); + return ProcessingExceptionHandler.ProcessingHandlerResponse.FAIL; + } + + @Override + public void configure(final Map<String, ?> configs) { + // No-op + } + } + /** * Metric name for dropped records total.
* diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/ErrorHandlerContext.java b/streams/src/main/java/org/apache/kafka/streams/errors/ErrorHandlerContext.java index d471673a48ed4..59ccab6fbf38c 100644 --- a/streams/src/main/java/org/apache/kafka/streams/errors/ErrorHandlerContext.java +++ b/streams/src/main/java/org/apache/kafka/streams/errors/ErrorHandlerContext.java @@ -147,4 +147,38 @@ public interface ErrorHandlerContext { * @return The timestamp. */ long timestamp(); + + /** + * Return the non-deserialized byte[] of the input message key if the context has been triggered by a message. + * + *
<p>If this method is invoked within a {@link Punctuator#punctuate(long) + * punctuation callback}, or while processing a record that was forwarded by a punctuation + * callback, it will return {@code null}. + * + *
<p>If this method is invoked in a sub-topology due to a repartition, the returned key would be the one sent + * to the repartition topic. + * + *
<p>Always returns {@code null} if this method is invoked within a + * {@code ProductionExceptionHandler.handle(ErrorHandlerContext, ProducerRecord, Exception)}. + * + * @return the raw bytes of the key of the source message + */ + byte[] sourceRawKey(); + + /** + * Return the non-deserialized byte[] of the input message value if the context has been triggered by a message. + * + *
<p>If this method is invoked within a {@link Punctuator#punctuate(long) + * punctuation callback}, or while processing a record that was forwarded by a punctuation + * callback, it will return {@code null}. + * + *
<p>If this method is invoked in a sub-topology due to a repartition, the returned value would be the one sent + * to the repartition topic. + * + *
<p>Always returns {@code null} if this method is invoked within a + * {@code ProductionExceptionHandler.handle(ErrorHandlerContext, ProducerRecord, Exception)}. + * + * @return the raw bytes of the value of the source message + */ + byte[] sourceRawValue(); } diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/internals/DefaultErrorHandlerContext.java b/streams/src/main/java/org/apache/kafka/streams/errors/internals/DefaultErrorHandlerContext.java index efaa6d57e7acc..0e85ce68c0369 100644 --- a/streams/src/main/java/org/apache/kafka/streams/errors/internals/DefaultErrorHandlerContext.java +++ b/streams/src/main/java/org/apache/kafka/streams/errors/internals/DefaultErrorHandlerContext.java @@ -33,6 +33,8 @@ public class DefaultErrorHandlerContext implements ErrorHandlerContext { private final Headers headers; private final String processorNodeId; private final TaskId taskId; + private final byte[] sourceRawKey; + private final byte[] sourceRawValue; private final long timestamp; private final ProcessorContext processorContext; @@ -44,7 +46,9 @@ public DefaultErrorHandlerContext(final ProcessorContext processorContext, final Headers headers, final String processorNodeId, final TaskId taskId, - final long timestamp) { + final long timestamp, + final byte[] sourceRawKey, + final byte[] sourceRawValue) { this.topic = topic; this.partition = partition; this.offset = offset; @@ -53,6 +57,8 @@ public DefaultErrorHandlerContext(final ProcessorContext processorContext, this.taskId = taskId; this.processorContext = processorContext; this.timestamp = timestamp; + this.sourceRawKey = sourceRawKey; + this.sourceRawValue = sourceRawValue; } @Override @@ -90,6 +96,14 @@ public long timestamp() { return timestamp; } + public byte[] sourceRawKey() { + return sourceRawKey; + } + + public byte[] sourceRawValue() { + return sourceRawValue; + } + @Override public String toString() { // we do exclude headers on purpose, to not accidentally log user data diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/RecordContext.java b/streams/src/main/java/org/apache/kafka/streams/processor/RecordContext.java index 6b6fd91c85355..f77d4f454ce37 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/RecordContext.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/RecordContext.java @@ -110,4 +110,31 @@ public interface RecordContext { */ Headers headers(); + /** + * Return the non-deserialized byte[] of the input message key if the context has been triggered by a message. + * + *
<p>If this method is invoked within a {@link Punctuator#punctuate(long) + * punctuation callback}, or while processing a record that was forwarded by a punctuation + * callback, it will return {@code null}. + * + *
<p>If this method is invoked in a sub-topology due to a repartition, the returned key would be the one sent + * to the repartition topic. + * + * @return the raw bytes of the key of the source message + */ + byte[] sourceRawKey(); + + /** + * Return the non-deserialized byte[] of the input message value if the context has been triggered by a message. + * + *
<p>If this method is invoked within a {@link Punctuator#punctuate(long) + * punctuation callback}, or while processing a record that was forwarded by a punctuation + * callback, it will return {@code null}. + * + *
<p>If this method is invoked in a sub-topology due to a repartition, the returned value would be the one sent + * to the repartition topic. + * + * @return the raw bytes of the value of the source message + */ + byte[] sourceRawValue(); } diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java index 8f739d0c0566a..93961daf97b79 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java @@ -260,7 +260,10 @@ public <K, V> void forward(final Record<K, V> record, final String childName) { recordContext.offset(), recordContext.partition(), recordContext.topic(), - record.headers()); + record.headers(), + recordContext.sourceRawKey(), + recordContext.sourceRawValue() + ); } if (childName == null) { diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorNode.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorNode.java index 5d245ef5f303e..1dddc55ca3c26 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorNode.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorNode.java @@ -215,7 +215,9 @@ public void process(final Record<KIn, VIn> record) { internalProcessorContext.recordContext().headers(), internalProcessorContext.currentNode().name(), internalProcessorContext.taskId(), - internalProcessorContext.recordContext().timestamp() + internalProcessorContext.recordContext().timestamp(), + internalProcessorContext.recordContext().sourceRawKey(), + internalProcessorContext.recordContext().sourceRawValue() ); final ProcessingExceptionHandler.ProcessingHandlerResponse response; diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorRecordContext.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorRecordContext.java index 839baaad87528..8198645eb652e 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorRecordContext.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorRecordContext.java @@ -24,6 +24,7 @@ import org.apache.kafka.streams.processor.api.RecordMetadata; import java.nio.ByteBuffer; +import java.util.Arrays; import java.util.Objects; import static java.nio.charset.StandardCharsets.UTF_8; @@ -37,6 +38,8 @@ public class ProcessorRecordContext implements RecordContext, RecordMetadata { private final String topic; private final int partition; private final Headers headers; + private byte[] sourceRawKey; + private byte[] sourceRawValue; public ProcessorRecordContext(final long timestamp, final long offset, @@ -48,6 +51,24 @@ public ProcessorRecordContext(final long timestamp, this.topic = topic; this.partition = partition; this.headers = Objects.requireNonNull(headers); + this.sourceRawKey = null; + this.sourceRawValue = null; + } + + public ProcessorRecordContext(final long timestamp, + final long offset, + final int partition, + final String topic, + final Headers headers, + final byte[] sourceRawKey, + final byte[] sourceRawValue) { + this.timestamp = timestamp; + this.offset = offset; + this.topic = topic; + this.partition = partition; + this.headers = Objects.requireNonNull(headers); + this.sourceRawKey = sourceRawKey; + this.sourceRawValue = sourceRawValue; } @Override @@
-75,6 +96,16 @@ public Headers headers() { return headers; } + @Override + public byte[] sourceRawKey() { + return sourceRawKey; + } + + @Override + public byte[] sourceRawValue() { + return sourceRawValue; + } + public long residentMemorySizeEstimate() { long size = 0; size += Long.BYTES; // value.context.timestamp @@ -176,6 +207,11 @@ public static ProcessorRecordContext deserialize(final ByteBuffer buffer) { return new ProcessorRecordContext(timestamp, offset, partition, topic, headers); } + public void freeRawRecord() { + this.sourceRawKey = null; + this.sourceRawValue = null; + } + @Override public boolean equals(final Object o) { if (this == o) { @@ -189,7 +225,9 @@ public boolean equals(final Object o) { offset == that.offset && partition == that.partition && Objects.equals(topic, that.topic) && - Objects.equals(headers, that.headers); + Objects.equals(headers, that.headers) && + Arrays.equals(sourceRawKey, that.sourceRawKey) && + Arrays.equals(sourceRawValue, that.sourceRawValue); } /** diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java index d47db7ea94261..89cbf4d4c7d4e 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java @@ -259,6 +259,10 @@ public void send(final String topic, final ProducerRecord<byte[], byte[]> serializedRecord = new ProducerRecord<>(topic, partition, timestamp, keyBytes, valBytes, headers); + // Since many records can be in flight, + // free the raw records in the context to reduce memory pressure + freeRawInputRecordFromContext(context); + streamsProducer.send(serializedRecord, (metadata, exception) -> { try { // if there's already an exception record, skip logging offsets or new exceptions @@ -311,6 +315,12 @@ public void send(final String topic, }); } + private static void freeRawInputRecordFromContext(final InternalProcessorContext<Void, Void> context) { + if (context != null && context.recordContext() != null) { + context.recordContext().freeRawRecord(); + } + } + private void handleException(final ProductionExceptionHandler.SerializationExceptionOrigin origin, final String topic, final K key, @@ -388,7 +398,9 @@ private DefaultErrorHandlerContext errorHandlerContext(final InternalProcessorCo recordContext.headers(), processorNodeId, taskId, - recordContext.timestamp() + recordContext.timestamp(), + context.recordContext().sourceRawKey(), + context.recordContext().sourceRawValue() ) : new DefaultErrorHandlerContext( context, @@ -398,7 +410,9 @@ private DefaultErrorHandlerContext errorHandlerContext(final InternalProcessorCo new RecordHeaders(), processorNodeId, taskId, - -1L + -1L, + null, + null ); } diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordDeserializer.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordDeserializer.java index 6f9fe989552f8..153ca2e02f1ee 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordDeserializer.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordDeserializer.java @@ -95,7 +95,10 @@ public static void handleDeserializationFailure(final DeserializationExceptionHa rawRecord.headers(), sourceNodeName, processorContext.taskId(), - rawRecord.timestamp()); + rawRecord.timestamp(), + rawRecord.key(), + rawRecord.value() + ); final
DeserializationHandlerResponse response; try { diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java index d38d7b625ae8e..faa90572ca524 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java @@ -243,7 +243,7 @@ private void updateHead() { lastCorruptedRecord = raw; continue; } - headRecord = new StampedRecord(deserialized, timestamp); + headRecord = new StampedRecord(deserialized, timestamp, raw.key(), raw.value()); headRecordSizeInBytes = consumerRecordSizeInBytes(raw); } diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StampedRecord.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StampedRecord.java index c8ed35a9a8f6c..dd0a1298b6767 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StampedRecord.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StampedRecord.java @@ -23,8 +23,22 @@ public class StampedRecord extends Stamped<ConsumerRecord<?, ?>> { + private final byte[] rawKey; + private final byte[] rawValue; + public StampedRecord(final ConsumerRecord<?, ?> record, final long timestamp) { super(record, timestamp); + this.rawKey = null; + this.rawValue = null; + } + + public StampedRecord(final ConsumerRecord<?, ?> record, + final long timestamp, + final byte[] rawKey, + final byte[] rawValue) { + super(record, timestamp); + this.rawKey = rawKey; + this.rawValue = rawValue; } public String topic() { @@ -55,8 +69,26 @@ public Headers headers() { return value.headers(); } + public byte[] rawKey() { + return rawKey; + } + + public byte[] rawValue() { + return rawValue; + } + @Override public String toString() { return value.toString() + ", timestamp = " + timestamp; } + + @Override + public boolean equals(final Object other) { + return super.equals(other); + } + + @Override + public int hashCode() { + return super.hashCode(); + } } diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java index 93737d8228933..82e9c8d7fb110 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java @@ -856,7 +856,9 @@ private void doProcess(final long wallClockTime) { record.offset(), record.partition(), record.topic(), - record.headers() + record.headers(), + record.rawKey(), + record.rawValue() ); updateProcessorContext(currNode, wallClockTime, recordContext); @@ -938,7 +940,9 @@ record = null; recordContext.headers(), node.name(), id(), - recordContext.timestamp() + recordContext.timestamp(), + recordContext.sourceRawKey(), + recordContext.sourceRawValue() ); final ProcessingExceptionHandler.ProcessingHandlerResponse response; diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingKeyValueStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingKeyValueStore.java index f59271920f5a7..83343d04494d6 100644 --- a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingKeyValueStore.java +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingKeyValueStore.java @@ -277,7 +277,9 @@ private void putInternal(final Bytes key,
internalContext.recordContext().offset(), internalContext.recordContext().timestamp(), internalContext.recordContext().partition(), - internalContext.recordContext().topic() + internalContext.recordContext().topic(), + internalContext.recordContext().sourceRawKey(), + internalContext.recordContext().sourceRawValue() ) ); diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java index 00dbaa5589b2d..ec0c1bd077d6f 100644 --- a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java @@ -140,7 +140,9 @@ public void put(final Windowed<Bytes> key, final byte[] value) { internalContext.recordContext().offset(), internalContext.recordContext().timestamp(), internalContext.recordContext().partition(), - internalContext.recordContext().topic() + internalContext.recordContext().topic(), + internalContext.recordContext().sourceRawKey(), + internalContext.recordContext().sourceRawValue() ); internalContext.cache().put(cacheName, cacheFunction.cacheKey(binaryKey), entry); diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingWindowStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingWindowStore.java index f138ff9202a83..0432c1726cb3e 100644 --- a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingWindowStore.java +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingWindowStore.java @@ -158,7 +158,9 @@ public synchronized void put(final Bytes key, internalContext.recordContext().offset(), internalContext.recordContext().timestamp(), internalContext.recordContext().partition(), - internalContext.recordContext().topic() + internalContext.recordContext().topic(), + internalContext.recordContext().sourceRawKey(), + internalContext.recordContext().sourceRawValue() ); internalContext.cache().put(cacheName, cacheFunction.cacheKey(keyBytes), entry); diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/LRUCacheEntry.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/LRUCacheEntry.java index f4233c7cb1120..0cbd79714ccd6 100644 --- a/streams/src/main/java/org/apache/kafka/streams/state/internals/LRUCacheEntry.java +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/LRUCacheEntry.java @@ -32,7 +32,7 @@ class LRUCacheEntry { LRUCacheEntry(final byte[] value) { - this(value, new RecordHeaders(), false, -1, -1, -1, ""); + this(value, new RecordHeaders(), false, -1, -1, -1, "", null, null); } LRUCacheEntry(final byte[] value, @@ -41,8 +41,18 @@ class LRUCacheEntry { final long offset, final long timestamp, final int partition, - final String topic) { - final ProcessorRecordContext context = new ProcessorRecordContext(timestamp, offset, partition, topic, headers); + final String topic, + final byte[] rawKey, + final byte[] rawValue) { + final ProcessorRecordContext context = new ProcessorRecordContext( + timestamp, + offset, + partition, + topic, + headers, + rawKey, + rawValue + ); this.record = new ContextualRecord( value, diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/TimeOrderedCachingWindowStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimeOrderedCachingWindowStore.java index 7f443c3e32ce8..646cbf2ca3557 100644 ---
a/streams/src/main/java/org/apache/kafka/streams/state/internals/TimeOrderedCachingWindowStore.java +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimeOrderedCachingWindowStore.java @@ -261,7 +261,9 @@ public synchronized void put(final Bytes key, internalContext.recordContext().offset(), internalContext.recordContext().timestamp(), internalContext.recordContext().partition(), - internalContext.recordContext().topic() + internalContext.recordContext().topic(), + internalContext.recordContext().sourceRawKey(), + internalContext.recordContext().sourceRawValue() ); // Put to index first so that base can be evicted later @@ -279,7 +281,9 @@ public synchronized void put(final Bytes key, internalContext.recordContext().offset(), internalContext.recordContext().timestamp(), internalContext.recordContext().partition(), - "" + "", + internalContext.recordContext().sourceRawKey(), + internalContext.recordContext().sourceRawValue() ); final Bytes indexKey = KeyFirstWindowKeySchema.toStoreKeyBinary(key, windowStartTimestamp, 0); internalContext.cache().put(cacheName, indexKeyCacheFunction.cacheKey(indexKey), emptyEntry); diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorNodeTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorNodeTest.java index 86f617e7f3483..5341cd25f0d5d 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorNodeTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorNodeTest.java @@ -80,6 +80,8 @@ public class ProcessorNodeTest { private static final String NAME = "name"; private static final String KEY = "key"; private static final String VALUE = "value"; + private static final byte[] RAW_KEY = KEY.getBytes(); + private static final byte[] RAW_VALUE = VALUE.getBytes(); @Test public void shouldThrowStreamsExceptionIfExceptionCaughtDuringInit() { @@ -331,7 +333,9 @@ private InternalProcessorContext mockInternalProcessorContext() OFFSET, PARTITION, TOPIC, - new RecordHeaders())); + new RecordHeaders(), + RAW_KEY, + RAW_VALUE)); when(internalProcessorContext.currentNode()).thenReturn(new ProcessorNode<>(NAME)); return internalProcessorContext; @@ -359,6 +363,9 @@ public ProcessingExceptionHandler.ProcessingHandlerResponse handle(final ErrorHa assertEquals(internalProcessorContext.currentNode().name(), context.processorNodeId()); assertEquals(internalProcessorContext.taskId(), context.taskId()); assertEquals(internalProcessorContext.recordContext().timestamp(), context.timestamp()); + assertEquals(internalProcessorContext.recordContext().sourceRawKey(), context.sourceRawKey()); + assertEquals(internalProcessorContext.recordContext().sourceRawValue(), context.sourceRawValue()); + assertEquals(KEY, record.key()); assertEquals(VALUE, record.value()); assertInstanceOf(RuntimeException.class, exception); diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java index b01b87ed85f82..4fb5f91ba7526 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java @@ -100,6 +100,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static 
org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; @@ -1890,6 +1892,68 @@ public void shouldNotSendIfSendOfOtherTaskFailedInCallback() { )); } + @Test + public void shouldFreeRawRecordsInContextBeforeSending() { + final KafkaException exception = new KafkaException("KABOOM!"); + final byte[][] sourceRawData = new byte[][]{new byte[]{}, new byte[]{}}; + + final RecordCollector collector = new RecordCollectorImpl( + logContext, + taskId, + getExceptionalStreamsProducerOnSend(exception), + new ProductionExceptionHandler() { + @Override + public void configure(final Map<String, ?> configs) { + + } + + @Override + public ProductionExceptionHandlerResponse handle(final ErrorHandlerContext context, final ProducerRecord<byte[], byte[]> record, final Exception exception) { + sourceRawData[0] = context.sourceRawKey(); + sourceRawData[1] = context.sourceRawValue(); + return ProductionExceptionHandlerResponse.CONTINUE; + } + }, + streamsMetrics, + topology + ); + + collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, sinkNodeName, context, streamPartitioner); + + assertNull(sourceRawData[0]); + assertNull(sourceRawData[1]); + } + + + @Test + public void shouldHaveRawDataDuringExceptionInSerialization() { + final byte[][] sourceRawData = new byte[][]{new byte[]{}, new byte[]{}}; + try (final ErrorStringSerializer errorSerializer = new ErrorStringSerializer()) { + final RecordCollector collector = newRecordCollector( + new ProductionExceptionHandler() { + @Override + @SuppressWarnings({"rawtypes", "unused"}) + public ProductionExceptionHandlerResponse handleSerializationException(final ErrorHandlerContext context, final ProducerRecord record, final Exception exception, final SerializationExceptionOrigin origin) { + sourceRawData[0] = context.sourceRawKey(); + sourceRawData[1] = context.sourceRawValue(); + return ProductionExceptionHandlerResponse.CONTINUE; + } + + @Override + public void configure(final Map<String, ?> configs) { + + } + } + ); + collector.initialize(); + + collector.send(topic, "hello", "val", null, 0, null, (Serializer) errorSerializer, stringSerializer, sinkNodeName, context); + + assertNotNull(sourceRawData[0]); + assertNotNull(sourceRawData[1]); + } + } + private RecordCollector newRecordCollector(final ProductionExceptionHandler productionExceptionHandler) { return new RecordCollectorImpl( logContext, diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/NamedCacheTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/NamedCacheTest.java index 9a68e258c5dc1..fd138c7e71469 100644 --- a/streams/src/test/java/org/apache/kafka/streams/state/internals/NamedCacheTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/NamedCacheTest.java @@ -44,6 +44,8 @@ public class NamedCacheTest { private final Headers headers = new RecordHeaders(new Header[]{new RecordHeader("key", "value".getBytes())}); private NamedCache cache; + private final byte[] rawKey = new byte[]{0}; + private final byte[] rawValue = new byte[]{0}; @BeforeEach public void setUp() { @@ -64,7 +66,7 @@ public void shouldKeepTrackOfMostRecentlyAndLeastRecentlyUsed() { final byte[] key = stringStringKeyValue.key.getBytes(); final byte[] value = stringStringKeyValue.value.getBytes();
cache.put(Bytes.wrap(key), - new LRUCacheEntry(value, new RecordHeaders(), true, 1, 1, 1, "")); + new LRUCacheEntry(value, new RecordHeaders(), true, 1, 1, 1, "", rawKey, rawValue)); final LRUCacheEntry head = cache.first(); final LRUCacheEntry tail = cache.last(); assertEquals(new String(head.value()), stringStringKeyValue.value); @@ -152,9 +154,9 @@ public void shouldEvictEldestEntry() { @Test public void shouldFlushDirtEntriesOnEviction() { final List<ThreadCache.DirtyEntry> flushed = new ArrayList<>(); - cache.put(Bytes.wrap(new byte[]{0}), new LRUCacheEntry(new byte[]{10}, headers, true, 0, 0, 0, "")); + cache.put(Bytes.wrap(new byte[]{0}), new LRUCacheEntry(new byte[]{10}, headers, true, 0, 0, 0, "", rawKey, rawValue)); cache.put(Bytes.wrap(new byte[]{1}), new LRUCacheEntry(new byte[]{20})); - cache.put(Bytes.wrap(new byte[]{2}), new LRUCacheEntry(new byte[]{30}, headers, true, 0, 0, 0, "")); + cache.put(Bytes.wrap(new byte[]{2}), new LRUCacheEntry(new byte[]{30}, headers, true, 0, 0, 0, "", rawKey, rawValue)); cache.setListener(flushed::addAll); @@ -176,16 +178,16 @@ public void shouldNotThrowNullPointerWhenCacheIsEmptyAndEvictionCalled() { @Test public void shouldThrowIllegalStateExceptionWhenTryingToOverwriteDirtyEntryWithCleanEntry() { - cache.put(Bytes.wrap(new byte[]{0}), new LRUCacheEntry(new byte[]{10}, headers, true, 0, 0, 0, "")); + cache.put(Bytes.wrap(new byte[]{0}), new LRUCacheEntry(new byte[]{10}, headers, true, 0, 0, 0, "", rawKey, rawValue)); assertThrows(IllegalStateException.class, () -> cache.put(Bytes.wrap(new byte[]{0}), - new LRUCacheEntry(new byte[]{10}, new RecordHeaders(), false, 0, 0, 0, ""))); + new LRUCacheEntry(new byte[]{10}, new RecordHeaders(), false, 0, 0, 0, "", rawKey, rawValue))); } @Test public void shouldRemoveDeletedValuesOnFlush() { cache.setListener(dirty -> { /* no-op */ }); - cache.put(Bytes.wrap(new byte[]{0}), new LRUCacheEntry(null, headers, true, 0, 0, 0, "")); - cache.put(Bytes.wrap(new byte[]{1}), new LRUCacheEntry(new byte[]{20}, new RecordHeaders(), true, 0, 0, 0, "")); + cache.put(Bytes.wrap(new byte[]{0}), new LRUCacheEntry(null, headers, true, 0, 0, 0, "", rawKey, rawValue)); + cache.put(Bytes.wrap(new byte[]{1}), new LRUCacheEntry(new byte[]{20}, new RecordHeaders(), true, 0, 0, 0, "", rawKey, rawValue)); cache.flush(); assertEquals(1, cache.size()); assertNotNull(cache.get(Bytes.wrap(new byte[]{1}))); @@ -193,7 +195,7 @@ public void shouldRemoveDeletedValuesOnFlush() { @Test public void shouldBeReentrantAndNotBreakLRU() { - final LRUCacheEntry dirty = new LRUCacheEntry(new byte[]{3}, new RecordHeaders(), true, 0, 0, 0, ""); + final LRUCacheEntry dirty = new LRUCacheEntry(new byte[]{3}, new RecordHeaders(), true, 0, 0, 0, "", rawKey, rawValue); final LRUCacheEntry clean = new LRUCacheEntry(new byte[]{3}); cache.put(Bytes.wrap(new byte[]{0}), dirty); cache.put(Bytes.wrap(new byte[]{1}), clean); @@ -236,7 +238,7 @@ public void shouldBeReentrantAndNotBreakLRU() { @Test public void shouldNotThrowIllegalArgumentAfterEvictingDirtyRecordAndThenPuttingNewRecordWithSameKey() { - final LRUCacheEntry dirty = new LRUCacheEntry(new byte[]{3}, new RecordHeaders(), true, 0, 0, 0, ""); + final LRUCacheEntry dirty = new LRUCacheEntry(new byte[]{3}, new RecordHeaders(), true, 0, 0, 0, "", rawKey, rawValue); final LRUCacheEntry clean = new LRUCacheEntry(new byte[]{3}); final Bytes key = Bytes.wrap(new byte[] {3}); cache.setListener(dirty1 -> cache.put(key, clean)); diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/ThreadCacheTest.java
b/streams/src/test/java/org/apache/kafka/streams/state/internals/ThreadCacheTest.java index 9e904a2ab2c7e..a1cc0cec6fcfa 100644 --- a/streams/src/test/java/org/apache/kafka/streams/state/internals/ThreadCacheTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/ThreadCacheTest.java @@ -48,6 +48,8 @@ public class ThreadCacheTest { final String namespace2 = "0.2-namespace"; private final LogContext logContext = new LogContext("testCache "); private final byte[][] bytes = new byte[][]{{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}, {10}}; + private final byte[] rawKey = new byte[]{0}; + private final byte[] rawValue = new byte[]{0}; @Test public void basicPutGet() { @@ -65,7 +67,7 @@ public void basicPutGet() { for (final KeyValue<String, String> kvToInsert : toInsert) { final Bytes key = Bytes.wrap(kvToInsert.key.getBytes()); final byte[] value = kvToInsert.value.getBytes(); - cache.put(namespace, key, new LRUCacheEntry(value, new RecordHeaders(), true, 1L, 1L, 1, "")); + cache.put(namespace, key, new LRUCacheEntry(value, new RecordHeaders(), true, 1L, 1L, 1, "", rawKey, rawValue)); } for (final KeyValue<String, String> kvToInsert : toInsert) { @@ -98,7 +100,7 @@ private void checkOverheads(final double entryFactor, final String keyStr = "K" + i; final Bytes key = Bytes.wrap(keyStr.getBytes()); final byte[] value = new byte[valueSizeBytes]; - cache.put(namespace, key, new LRUCacheEntry(value, new RecordHeaders(), true, 1L, 1L, 1, "")); + cache.put(namespace, key, new LRUCacheEntry(value, new RecordHeaders(), true, 1L, 1L, 1, "", rawKey, rawValue)); } @@ -176,7 +178,7 @@ public void evict() { for (final KeyValue<String, String> kvToInsert : toInsert) { final Bytes key = Bytes.wrap(kvToInsert.key.getBytes()); final byte[] value = kvToInsert.value.getBytes(); - cache.put(namespace, key, new LRUCacheEntry(value, new RecordHeaders(), true, 1, 1, 1, "")); + cache.put(namespace, key, new LRUCacheEntry(value, new RecordHeaders(), true, 1, 1, 1, "", rawKey, rawValue)); } for (int i = 0; i < expected.size(); i++) { @@ -617,7 +619,7 @@ public void shouldResizeAndShrink() { } private LRUCacheEntry dirtyEntry(final byte[] key) { - return new LRUCacheEntry(key, new RecordHeaders(), true, -1, -1, -1, ""); + return new LRUCacheEntry(key, new RecordHeaders(), true, -1, -1, -1, "", rawKey, rawValue); } private LRUCacheEntry cleanEntry(final byte[] key) { diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/TimeOrderedCachingPersistentWindowStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimeOrderedCachingPersistentWindowStoreTest.java index 7aba24344574c..ffa509d518871 100644 --- a/streams/src/test/java/org/apache/kafka/streams/state/internals/TimeOrderedCachingPersistentWindowStoreTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimeOrderedCachingPersistentWindowStoreTest.java @@ -938,7 +938,10 @@ public void shouldSkipNonExistBaseKeyInCache(final boolean hasIndex) { context.recordContext().offset(), context.recordContext().timestamp(), context.recordContext().partition(), - "") + "", + context.recordContext().sourceRawKey(), + context.recordContext().sourceRawValue() + ) ); underlyingStore.put(key, value, 1); diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/TimeOrderedWindowStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimeOrderedWindowStoreTest.java index 9eb9ec21b5ee7..9d0db9bae0fbb 100644 ---
a/streams/src/test/java/org/apache/kafka/streams/state/internals/TimeOrderedWindowStoreTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimeOrderedWindowStoreTest.java @@ -944,7 +944,9 @@ public void shouldSkipNonExistBaseKeyInCache(final boolean hasIndex) { context.recordContext().offset(), context.recordContext().timestamp(), context.recordContext().partition(), - "" + "", + context.recordContext().sourceRawKey(), + context.recordContext().sourceRawValue() ) ); diff --git a/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java b/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java index 228df8d63a1ac..ed68c86c49020 100644 --- a/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java +++ b/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java @@ -56,6 +56,7 @@ import org.apache.kafka.streams.state.internals.ThreadCache.DirtyEntryFlushListener; import java.io.File; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.ArrayList; import java.util.HashMap; @@ -244,7 +245,9 @@ public InternalMockProcessorContext(final File stateDir, 0, 0, "topic", - new RecordHeaders() + new RecordHeaders(), + "sourceKey".getBytes(StandardCharsets.UTF_8), + "sourceValue".getBytes(StandardCharsets.UTF_8) ); }
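For context, a minimal sketch (illustrative only, not part of this patch) of how a user-supplied ProcessingExceptionHandler could consume the new sourceRawKey()/sourceRawValue() accessors, e.g. to route the original raw record to a dead-letter topic. The handler class name, the topic name, and the producer wiring are assumptions:

import org.apache.kafka.clients.producer.Producer;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.streams.errors.ErrorHandlerContext;
import org.apache.kafka.streams.errors.ProcessingExceptionHandler;
import org.apache.kafka.streams.processor.api.Record;

import java.util.Map;

public class DeadLetterProcessingExceptionHandler implements ProcessingExceptionHandler {
    private Producer<byte[], byte[]> deadLetterProducer; // assumed to be created in configure()

    @Override
    public ProcessingHandlerResponse handle(final ErrorHandlerContext context,
                                            final Record<?, ?> record,
                                            final Exception exception) {
        // Per the javadoc above, the raw key/value are null when the failed record
        // originated from a punctuation callback, so guard before building the
        // dead-letter record.
        if (context.sourceRawKey() != null || context.sourceRawValue() != null) {
            deadLetterProducer.send(new ProducerRecord<>(
                "dead-letter-topic", context.sourceRawKey(), context.sourceRawValue()));
        }
        return ProcessingHandlerResponse.CONTINUE;
    }

    @Override
    public void configure(final Map<String, ?> configs) {
        // dead-letter producer construction omitted for brevity
    }
}

Note that after a repartition the bytes forwarded here are the ones written to the repartition topic, not the original input topic record.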