diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryWindowBytesStoreSupplier.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryWindowBytesStoreSupplier.java index a6709ae64e210..ace4de0e74ef8 100644 --- a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryWindowBytesStoreSupplier.java +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryWindowBytesStoreSupplier.java @@ -16,7 +16,6 @@ */ package org.apache.kafka.streams.state.internals; -import org.apache.kafka.common.serialization.Serdes; import org.apache.kafka.common.utils.Bytes; import org.apache.kafka.streams.state.WindowBytesStoreSupplier; import org.apache.kafka.streams.state.WindowStore; @@ -44,13 +43,11 @@ public String name() { @Override public WindowStore get() { - return new InMemoryWindowStore<>(name, - Serdes.Bytes(), - Serdes.ByteArray(), - retentionPeriod, - windowSize, - retainDuplicates, - metricsScope()); + return new InMemoryWindowStore(name, + retentionPeriod, + windowSize, + retainDuplicates, + metricsScope()); } @Override diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryWindowStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryWindowStore.java index 6e9b96bd3cf04..77820c57718ee 100644 --- a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryWindowStore.java +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryWindowStore.java @@ -16,18 +16,17 @@ */ package org.apache.kafka.streams.state.internals; +import java.nio.ByteBuffer; import org.apache.kafka.clients.consumer.ConsumerRecord; import org.apache.kafka.common.metrics.Sensor; -import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Bytes; import org.apache.kafka.streams.KeyValue; import org.apache.kafka.streams.kstream.Windowed; import org.apache.kafka.streams.kstream.internals.TimeWindow; import org.apache.kafka.streams.processor.ProcessorContext; import org.apache.kafka.streams.processor.StateStore; import org.apache.kafka.streams.processor.internals.InternalProcessorContext; -import org.apache.kafka.streams.processor.internals.ProcessorStateManager; import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; -import org.apache.kafka.streams.state.StateSerdes; import org.apache.kafka.streams.state.WindowStore; import org.apache.kafka.streams.state.WindowStoreIterator; import org.apache.kafka.streams.state.KeyValueIterator; @@ -44,18 +43,16 @@ import java.util.TreeMap; import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addInvocationRateAndCount; -import static org.apache.kafka.streams.state.internals.WindowKeySchema.extractStoreKey; +import static org.apache.kafka.streams.state.internals.WindowKeySchema.extractStoreKeyBytes; import static org.apache.kafka.streams.state.internals.WindowKeySchema.extractStoreTimestamp; -public class InMemoryWindowStore, V> implements WindowStore { +public class InMemoryWindowStore implements WindowStore { private static final Logger LOG = LoggerFactory.getLogger(InMemoryWindowStore.class); + private static final int SEQNUM_SIZE = 4; private final String name; - private final Serde keySerde; - private final Serde valueSerde; private final String metricScope; - private StateSerdes serdes; private InternalProcessorContext context; private Sensor expiredRecordSensor; private int seqnum = 0; @@ -65,20 +62,16 @@ public class InMemoryWindowStore, V> implements WindowSt private final long windowSize; private final boolean retainDuplicates; - private final NavigableMap, V>> segmentMap; + private final NavigableMap> segmentMap; private volatile boolean open = false; InMemoryWindowStore(final String name, - final Serde keySerde, - final Serde valueSerde, final long retentionPeriod, final long windowSize, final boolean retainDuplicates, final String metricScope) { this.name = name; - this.keySerde = keySerde; - this.valueSerde = valueSerde; this.retentionPeriod = retentionPeriod; this.windowSize = windowSize; this.retainDuplicates = retainDuplicates; @@ -97,12 +90,6 @@ public String name() { public void init(final ProcessorContext context, final StateStore root) { this.context = (InternalProcessorContext) context; - // construct the serde - this.serdes = new StateSerdes<>( - ProcessorStateManager.storeChangelogTopic(context.applicationId(), name), - keySerde == null ? (Serde) context.keySerde() : keySerde, - valueSerde == null ? (Serde) context.valueSerde() : valueSerde); - final StreamsMetricsImpl metrics = this.context.metrics(); final String taskName = context.taskId().toString(); expiredRecordSensor = metrics.storeLevelSensor( @@ -120,33 +107,35 @@ public void init(final ProcessorContext context, final StateStore root) { if (root != null) { context.register(root, (key, value) -> { - put(extractStoreKey(key, serdes), serdes.valueFrom(value), extractStoreTimestamp(key)); + put(Bytes.wrap(extractStoreKeyBytes(key)), value, extractStoreTimestamp(key)); }); } this.open = true; } @Override - public void put(final K key, final V value) { + public void put(final Bytes key, final byte[] value) { put(key, value, context.timestamp()); } @Override - public void put(final K key, final V value, final long windowStartTimestamp) { + public void put(final Bytes key, final byte[] value, final long windowStartTimestamp) { removeExpiredSegments(); maybeUpdateSeqnumForDups(); this.observedStreamTime = Math.max(this.observedStreamTime, windowStartTimestamp); + final Bytes keyBytes = retainDuplicates ? wrapForDups(key, seqnum) : key; + if (windowStartTimestamp <= this.observedStreamTime - this.retentionPeriod) { expiredRecordSensor.record(); LOG.debug("Skipping record for expired segment."); } else { if (value != null) { this.segmentMap.computeIfAbsent(windowStartTimestamp, t -> new TreeMap<>()); - this.segmentMap.get(windowStartTimestamp).put(new WrappedK<>(key, seqnum), value); + this.segmentMap.get(windowStartTimestamp).put(keyBytes, value); } else { this.segmentMap.computeIfPresent(windowStartTimestamp, (t, kvMap) -> { - kvMap.remove(new WrappedK<>(key, seqnum)); + kvMap.remove(keyBytes); return kvMap; }); } @@ -154,77 +143,79 @@ public void put(final K key, final V value, final long windowStartTimestamp) { } @Override - public V fetch(final K key, final long windowStartTimestamp) { + public byte[] fetch(final Bytes key, final long windowStartTimestamp) { removeExpiredSegments(); - final NavigableMap, V> kvMap = this.segmentMap.get(windowStartTimestamp); + final NavigableMap kvMap = this.segmentMap.get(windowStartTimestamp); if (kvMap == null) { return null; } else { - return kvMap.get(new WrappedK<>(key, seqnum)); + return kvMap.get(key); } } @Deprecated @Override - public WindowStoreIterator fetch(final K key, final long timeFrom, final long timeTo) { + public WindowStoreIterator fetch(final Bytes key, final long timeFrom, final long timeTo) { removeExpiredSegments(); - final List> records = retainDuplicates ? fetchWithDuplicates(key, timeFrom, timeTo) : fetchUnique(key, timeFrom, timeTo); + final List> records = retainDuplicates ? fetchWithDuplicates(key, timeFrom, timeTo) : fetchUnique(key, timeFrom, timeTo); - return new InMemoryWindowStoreIterator<>(records.listIterator()); + return new InMemoryWindowStoreIterator(records.listIterator()); } @Deprecated @Override - public KeyValueIterator, V> fetch(final K from, final K to, final long timeFrom, final long timeTo) { + public KeyValueIterator, byte[]> fetch(final Bytes from, + final Bytes to, + final long timeFrom, + final long timeTo) { removeExpiredSegments(); - final List, V>> returnSet = new LinkedList<>(); + final List, byte[]>> returnSet = new LinkedList<>(); // add one b/c records expire exactly retentionPeriod ms after created final long minTime = Math.max(timeFrom, this.observedStreamTime - this.retentionPeriod + 1); - final WrappedK keyFrom = new WrappedK<>(from, 0); - final WrappedK keyTo = new WrappedK<>(to, Integer.MAX_VALUE); + final Bytes keyFrom = retainDuplicates ? wrapForDups(from, 0) : from; + final Bytes keyTo = retainDuplicates ? wrapForDups(to, Integer.MAX_VALUE) : to; - for (final Map.Entry, V>> segmentMapEntry : this.segmentMap.subMap(minTime, true, timeTo, true).entrySet()) { - for (final Map.Entry, V> kvMapEntry : segmentMapEntry.getValue().subMap(keyFrom, true, keyTo, true).entrySet()) { - final WrappedK wrappedKey = kvMapEntry.getKey(); - returnSet.add(getWindowedKeyValue(wrappedKey.getKey(), segmentMapEntry.getKey(), kvMapEntry.getValue())); + for (final Map.Entry> segmentMapEntry : this.segmentMap.subMap(minTime, true, timeTo, true).entrySet()) { + for (final Map.Entry kvMapEntry : segmentMapEntry.getValue().subMap(keyFrom, true, keyTo, true).entrySet()) { + final Bytes keyBytes = retainDuplicates ? getKey(kvMapEntry.getKey()) : kvMapEntry.getKey(); + returnSet.add(getWindowedKeyValue(keyBytes, segmentMapEntry.getKey(), kvMapEntry.getValue())); } } - return new InMemoryWindowedKeyValueIterator<>(returnSet.listIterator()); + return new InMemoryWindowedKeyValueIterator(returnSet.listIterator()); } @Deprecated @Override - public KeyValueIterator, V> fetchAll(final long timeFrom, final long timeTo) { + public KeyValueIterator, byte[]> fetchAll(final long timeFrom, final long timeTo) { removeExpiredSegments(); - final List, V>> returnSet = new LinkedList<>(); + final List, byte[]>> returnSet = new LinkedList<>(); // add one b/c records expire exactly retentionPeriod ms after created final long minTime = Math.max(timeFrom, this.observedStreamTime - this.retentionPeriod + 1); - for (final Map.Entry, V>> segmentMapEntry : this.segmentMap.subMap(minTime, true, timeTo, true).entrySet()) { - for (final Map.Entry, V> kvMapEntry : segmentMapEntry.getValue().entrySet()) { - final WrappedK wrappedKey = kvMapEntry.getKey(); - returnSet.add(getWindowedKeyValue(wrappedKey.getKey(), segmentMapEntry.getKey(), kvMapEntry.getValue())); + for (final Map.Entry> segmentMapEntry : this.segmentMap.subMap(minTime, true, timeTo, true).entrySet()) { + for (final Map.Entry kvMapEntry : segmentMapEntry.getValue().entrySet()) { + final Bytes keyBytes = retainDuplicates ? getKey(kvMapEntry.getKey()) : kvMapEntry.getKey(); + returnSet.add(getWindowedKeyValue(keyBytes, segmentMapEntry.getKey(), kvMapEntry.getValue())); } } - return new InMemoryWindowedKeyValueIterator<>(returnSet.listIterator()); + return new InMemoryWindowedKeyValueIterator(returnSet.listIterator()); } @Override - public KeyValueIterator, V> all() { + public KeyValueIterator, byte[]> all() { removeExpiredSegments(); - final List, V>> returnSet = new LinkedList<>(); + final List, byte[]>> returnSet = new LinkedList<>(); - for (final Entry, V>> segmentMapEntry : this.segmentMap.entrySet()) { - for (final Entry, V> kvMapEntry : segmentMapEntry.getValue().entrySet()) { - final WrappedK wrappedKey = kvMapEntry.getKey(); - returnSet.add(getWindowedKeyValue(wrappedKey.getKey(), segmentMapEntry.getKey(), - kvMapEntry.getValue())); + for (final Entry> segmentMapEntry : this.segmentMap.entrySet()) { + for (final Entry kvMapEntry : segmentMapEntry.getValue().entrySet()) { + final Bytes keyBytes = retainDuplicates ? getKey(kvMapEntry.getKey()) : kvMapEntry.getKey(); + returnSet.add(getWindowedKeyValue(keyBytes, segmentMapEntry.getKey(), kvMapEntry.getValue())); } } - return new InMemoryWindowedKeyValueIterator<>(returnSet.listIterator()); + return new InMemoryWindowedKeyValueIterator(returnSet.listIterator()); } @Override @@ -248,14 +239,14 @@ public void close() { this.open = false; } - private List> fetchUnique(final K key, final long timeFrom, final long timeTo) { - final List> returnSet = new LinkedList<>(); + private List> fetchUnique(final Bytes key, final long timeFrom, final long timeTo) { + final List> returnSet = new LinkedList<>(); // add one b/c records expire exactly retentionPeriod ms after created final long minTime = Math.max(timeFrom, this.observedStreamTime - this.retentionPeriod + 1); - for (final Map.Entry, V>> segmentMapEntry : this.segmentMap.subMap(minTime, true, timeTo, true).entrySet()) { - final V value = segmentMapEntry.getValue().get(new WrappedK<>(key, seqnum)); + for (final Map.Entry> segmentMapEntry : this.segmentMap.subMap(minTime, true, timeTo, true).entrySet()) { + final byte[] value = segmentMapEntry.getValue().get(key); if (value != null) { returnSet.add(new KeyValue<>(segmentMapEntry.getKey(), value)); } @@ -263,16 +254,16 @@ private List> fetchUnique(final K key, final long timeFrom, fi return returnSet; } - private List> fetchWithDuplicates(final K key, final long timeFrom, final long timeTo) { - final List> returnSet = new LinkedList<>(); + private List> fetchWithDuplicates(final Bytes key, final long timeFrom, final long timeTo) { + final List> returnSet = new LinkedList<>(); // add one b/c records expire exactly retentionPeriod ms after created final long minTime = Math.max(timeFrom, this.observedStreamTime - this.retentionPeriod + 1); - final WrappedK keyFrom = new WrappedK<>(key, 0); - final WrappedK keyTo = new WrappedK<>(key, Integer.MAX_VALUE); + final Bytes keyFrom = wrapForDups(key, 0); + final Bytes keyTo = wrapForDups(key, Integer.MAX_VALUE); - for (final Map.Entry, V>> segmentMapEntry : this.segmentMap.subMap(minTime, true, timeTo, true).entrySet()) { - for (final Map.Entry, V> kvMapEntry : segmentMapEntry.getValue().subMap(keyFrom, true, keyTo, true).entrySet()) { + for (final Map.Entry> segmentMapEntry : this.segmentMap.subMap(minTime, true, timeTo, true).entrySet()) { + for (final Map.Entry kvMapEntry : segmentMapEntry.getValue().subMap(keyFrom, true, keyTo, true).entrySet()) { returnSet.add(new KeyValue<>(segmentMapEntry.getKey(), kvMapEntry.getValue())); } } @@ -284,8 +275,10 @@ private void removeExpiredSegments() { this.segmentMap.headMap(minLiveTime, true).clear(); } - private KeyValue, V> getWindowedKeyValue(final K key, final long startTimestamp, final V value) { - final Windowed windowedK = new Windowed<>(key, new TimeWindow(startTimestamp, startTimestamp + windowSize)); + private KeyValue, byte[]> getWindowedKeyValue(final Bytes key, + final long startTimestamp, + final byte[] value) { + final Windowed windowedK = new Windowed<>(key, new TimeWindow(startTimestamp, startTimestamp + windowSize)); return new KeyValue<>(windowedK, value); } @@ -295,34 +288,26 @@ private void maybeUpdateSeqnumForDups() { } } - private static class WrappedK> implements Comparable> { - private final K key; - private final int seqnum; + private static Bytes wrapForDups(final Bytes key, final int seqnum) { + final ByteBuffer buf = ByteBuffer.allocate(key.get().length + SEQNUM_SIZE); + buf.put(key.get()); + buf.putInt(seqnum); - WrappedK(final K key, final int seqnum) { - this.key = key; - this.seqnum = seqnum; - } + return Bytes.wrap(buf.array()); + } - public K getKey() { - return this.key; - } + private static Bytes getKey(final Bytes keyBytes) { + final byte[] bytes = new byte[keyBytes.get().length - SEQNUM_SIZE]; + System.arraycopy(keyBytes.get(), 0, bytes, 0, bytes.length); + return Bytes.wrap(bytes); - public int compareTo(final WrappedK k) { - final int compareKeys = this.key.compareTo(k.key); - if (compareKeys == 0) { - return this.seqnum - k.seqnum; - } else { - return compareKeys; - } - } } - private static class InMemoryWindowStoreIterator implements WindowStoreIterator { + private static class InMemoryWindowStoreIterator implements WindowStoreIterator { - private ListIterator> iterator; + private ListIterator> iterator; - InMemoryWindowStoreIterator(final ListIterator> iterator) { + InMemoryWindowStoreIterator(final ListIterator> iterator) { this.iterator = iterator; } @@ -332,7 +317,7 @@ public boolean hasNext() { } @Override - public KeyValue next() { + public KeyValue next() { return iterator.next(); } @@ -353,11 +338,11 @@ public void close() { } } - private static class InMemoryWindowedKeyValueIterator implements KeyValueIterator, V> { + private static class InMemoryWindowedKeyValueIterator implements KeyValueIterator, byte[]> { - ListIterator, V>> iterator; + ListIterator, byte[]>> iterator; - InMemoryWindowedKeyValueIterator(final ListIterator, V>> iterator) { + InMemoryWindowedKeyValueIterator(final ListIterator, byte[]>> iterator) { this.iterator = iterator; } @@ -367,16 +352,16 @@ public boolean hasNext() { } @Override - public KeyValue, V> next() { + public KeyValue, byte[]> next() { return iterator.next(); } @Override - public Windowed peekNextKey() { + public Windowed peekNextKey() { if (!hasNext()) { throw new NoSuchElementException(); } else { - final Windowed next = iterator.next().key; + final Windowed next = iterator.next().key; iterator.previous(); return next; }