From 176ba3c0ea7c68002eca6baa852f8f7a92ca7cbb Mon Sep 17 00:00:00 2001 From: "Colin P. Mccabe" Date: Fri, 22 Nov 2019 16:16:19 -0800 Subject: [PATCH] Reorganize ImplicitLinkedHashCollection and the generated collections Unite ImplicitLinkedHashCollection and ImplicitLinkedHashMultiCollection into a single class. It's simpler this way, and the code that needs to avoid duplicates can do it through the API. Fix the problem where inserting the same object into a collection twice causes corruption. This happened because due to the implicit linked list, there was only one set of previous and next pointers. We get around this in this PR by duplicating the inserted object when this situation arises. Generated Message classes now have a duplicate() function which creates a duplicate of the existing object, and a copy() method which sets all values to those of another object of the same type. Previously, we had to resort to clunky serializing and then deserializing to copy message objects. Generated Message classes now always have an equals() function which tests every field. Previously, message classes used as keys had special equals functions which only compared their key fields, which was confusing. We still support searching by the key because the find and findAll functions now take a comparator argument. 
--- .../utils/ImplicitLinkedHashCollection.java | 279 +++++++++++++----- .../ImplicitLinkedHashMultiCollection.java | 142 --------- .../ImplicitLinkedHashCollectionTest.java | 174 ++++++++--- ...ImplicitLinkedHashMultiCollectionTest.java | 173 ----------- .../scala/kafka/server/FetchSession.scala | 13 +- .../kafka/message/MessageDataGenerator.java | 179 +++++++++-- .../kafka/message/MessageGenerator.java | 7 +- 7 files changed, 505 insertions(+), 462 deletions(-) delete mode 100644 clients/src/main/java/org/apache/kafka/common/utils/ImplicitLinkedHashMultiCollection.java delete mode 100644 clients/src/test/java/org/apache/kafka/common/utils/ImplicitLinkedHashMultiCollectionTest.java diff --git a/clients/src/main/java/org/apache/kafka/common/utils/ImplicitLinkedHashCollection.java b/clients/src/main/java/org/apache/kafka/common/utils/ImplicitLinkedHashCollection.java index fba7d7ac584ec..f95eb9e603b07 100644 --- a/clients/src/main/java/org/apache/kafka/common/utils/ImplicitLinkedHashCollection.java +++ b/clients/src/main/java/org/apache/kafka/common/utils/ImplicitLinkedHashCollection.java @@ -20,6 +20,8 @@ import java.util.AbstractCollection; import java.util.AbstractSequentialList; import java.util.AbstractSet; +import java.util.ArrayList; +import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.ListIterator; @@ -27,21 +29,22 @@ import java.util.Set; /** - * A memory-efficient hash set which tracks the order of insertion of elements. + * A memory efficient collection which supports O(1) lookup of elements, and also tracks their + * order of insertion. Unlike in a Set or a Map, multiple elements can be inserted which are equal + * to each other. The type of contained objects must be a subclass of + * ImplicitLinkedHashCollection#Element. All elements must implement hashCode and equals. * - * Like java.util.LinkedHashSet, this collection maintains a linked list of elements. 
- * However, rather than using a separate linked list, this collection embeds the next - * and previous fields into the elements themselves. This reduces memory consumption, - * because it means that we only have to store one Java object per element, rather - * than multiple. + * The internal data structure is a hash table whose elements form a linked list. Rather than + * using a separate linked list, this collection embeds the "next" and "previous fields into the + * elements themselves. This reduces memory consumption, because it means that we only have to + * store one Java object per element, rather than multiple. * - * The next and previous fields are stored as array indices rather than pointers. - * This ensures that the fields only take 32 bits, even when pointers are 64 bits. - * It also makes the garbage collector's job easier, because it reduces the number of - * pointers that it must chase. + * The next and previous fields are stored as array indices rather than pointers. This ensures + * that the fields only take 32 bits, even when pointers are 64 bits. It also makes the garbage + * collector's job easier, because it reduces the number of pointers that it must chase. * - * This class uses linear probing. Unlike HashMap (but like HashTable), we don't force - * the size to be a power of 2. This saves memory. + * This class uses linear probing. Unlike HashMap (but like HashTable), we don't force the size to + * be a power of 2. This saves memory. * * This set does not allow null elements. It does not have internal synchronization. */ @@ -51,6 +54,40 @@ public interface Element { void setPrev(int prev); int next(); void setNext(int next); + Element duplicate(); + } + + /** + * Compares two elements. + * + * If compare(x, y) == true, then x.hashCode() == y.hashCode() must also be true. + */ + public interface Comparator { + boolean compare(T x, T y); + } + + /** + * Compares two elements using Object#equals(). 
+ */ + public static class ObjectEqualityComparator implements Comparator { + public static final ObjectEqualityComparator INSTANCE = new ObjectEqualityComparator<>(); + + @Override + public boolean compare(T x, T y) { + return x.equals(y); + } + } + + /** + * Compares two elements using reference equality. + */ + public static class ReferenceEqualityComparator implements Comparator { + public static final ReferenceEqualityComparator INSTANCE = new ReferenceEqualityComparator<>(); + + @Override + public boolean compare(T x, T y) { + return x == y; + } } /** @@ -100,6 +137,11 @@ public int next() { public void setNext(int next) { this.next = next; } + + @Override + public Element duplicate() { + return new HeadElement(); + } } private static Element indexToElement(Element head, Element[] elements, int index) { @@ -288,32 +330,34 @@ private ListIterator listIterator(int index) { return new ImplicitLinkedHashCollectionIterator(index); } - final int slot(Element[] curElements, Object e) { + final static int slot(Element[] curElements, Object e) { return (e.hashCode() & 0x7fffffff) % curElements.length; } /** - * Find an element matching an example element. + * Find the index of an element matching a given target element. * * Using the element's hash code, we can look up the slot where it belongs. * However, it may not have ended up in exactly this slot, due to a collision. * Therefore, we must search forward in the array until we hit a null, before * concluding that the element is not present. * - * @param key The element to match. + * @param target The element to match. + * @param comparator The comparator to use. * @return The match index, or INVALID_INDEX if no match was found. 
*/ - final private int findIndexOfEqualElement(Object key) { - if (key == null || size == 0) { + @SuppressWarnings("unchecked") + final private int findIndex(Object target, Comparator comparator) { + if (target == null || size == 0) { return INVALID_INDEX; } - int slot = slot(elements, key); + int slot = slot(elements, target); for (int seen = 0; seen < elements.length; seen++) { Element element = elements[slot]; if (element == null) { return INVALID_INDEX; } - if (key.equals(element)) { + if (comparator.compare((E) target, (E) element)) { return slot; } slot = (slot + 1) % elements.length; @@ -322,14 +366,26 @@ final private int findIndexOfEqualElement(Object key) { } /** - * An element e in the collection such that e.equals(key) and - * e.hashCode() == key.hashCode(). + * Return the first element e in the collection such that target.equals(e). + * + * @param target The element to match. + * + * @return The matching element, or null if there were none. + */ + public E find(E target) { + return find(target, ObjectEqualityComparator.INSTANCE); + } + + /** + * Return the first element e in the collection such that comparator.compare(target, e) == true. * - * @param key The element to match. - * @return The matching element, or null if there were none. + * @param target The element to match. + * @param comparator How to compare the two elements. + * + * @return The matching element, or null if there were none. */ - final public E find(E key) { - int index = findIndexOfEqualElement(key); + public E find(E target, Comparator comparator) { + int index = findIndex(target, comparator); if (index == INVALID_INDEX) { return null; } @@ -338,11 +394,51 @@ final public E find(E key) { return result; } + /** + * Returns all of the elements e in the collection such that + * target.equals(e). + * + * @param target The element to match. + * + * @return The matching element, or null if there were none. 
+ */ + public List findAll(E target) { + return findAll(target, ObjectEqualityComparator.INSTANCE); + } + + /** + * Returns all of the elements e in the collection such that + * comparator.compare(e, target) == true. + * + * @param target The element to match. + * + * @return All of the matching elements. + */ + @SuppressWarnings("unchecked") + public List findAll(E target, Comparator comparator) { + if (target == null || size == 0) { + return Collections.emptyList(); + } + ArrayList results = new ArrayList<>(); + int slot = slot(elements, target); + for (int seen = 0; seen < elements.length; seen++) { + E element = (E) elements[slot]; + if (element == null) { + break; + } + if (comparator.compare(target, element)) { + results.add(element); + } + slot = (slot + 1) % elements.length; + } + return results; + } + /** * Returns the number of elements in the set. */ @Override - final public int size() { + public int size() { return size; } @@ -350,11 +446,11 @@ final public int size() { * Returns true if there is at least one element e in the collection such * that key.equals(e) and key.hashCode() == e.hashCode(). * - * @param key The object to try to match. + * @param target The object to try to match. */ @Override - final public boolean contains(Object key) { - return findIndexOfEqualElement(key) != INVALID_INDEX; + public boolean contains(Object target) { + return findIndex(target, ObjectEqualityComparator.INSTANCE) != INVALID_INDEX; } private static int calculateCapacity(int expectedNumElements) { @@ -372,54 +468,92 @@ private static int calculateCapacity(int expectedNumElements) { * * @param newElement The new element. * - * @return True if the element was added to the collection; - * false if it was not, because there was an existing equal element. + * @return True if the element was added to the collection. + * False if the element could not be added because it was null. 
*/ @Override - final public boolean add(E newElement) { + public boolean add(E newElement) { + return addOrReplace(newElement, ReferenceEqualityComparator.INSTANCE); + } + + /** + * Add a new element to the collection. + * + * @param newElement The new element. + * @param comparator A comparator which will be used to determine if any object is + * similar enough to the new element to be replaced. + * + * @return True if the element was added to the collection. + * False if the element could not be added because it was null. + */ + final public boolean addOrReplace(E newElement, Comparator comparator) { if (newElement == null) { return false; } if ((size + 1) >= elements.length / 2) { changeCapacity(calculateCapacity(elements.length)); } - int slot = addInternal(newElement, elements); - if (slot >= 0) { - addToListTail(head, elements, slot); + if (addInternal(head, newElement, elements, comparator)) { size++; - return true; - } - return false; - } - - final public void mustAdd(E newElement) { - if (!add(newElement)) { - throw new RuntimeException("Unable to add " + newElement); } + return true; } /** * Adds a new element to the appropriate place in the elements array. * + * @param head The list head. * @param newElement The new element to add. * @param addElements The elements array. - * @return The index at which the element was inserted, or INVALID_INDEX - * if the element could not be inserted. + * @param comparator A comparator which will be used to determine if any object is + * similar enough to the new element to be replaced. + * + * @returns True if the size of the collection has increased. 
*/ - int addInternal(Element newElement, Element[] addElements) { + @SuppressWarnings("unchecked") + static boolean addInternal(Element head, + Element newElement, + Element[] addElements, + Comparator comparator) { int slot = slot(addElements, newElement); + int bestSlot = INVALID_INDEX; for (int seen = 0; seen < addElements.length; seen++) { Element element = addElements[slot]; - if (element == null) { - addElements[slot] = newElement; - return slot; - } - if (element.equals(newElement)) { - return INVALID_INDEX; + if (element == newElement) { + // If we find that this object has already been added to the collection, + // create a clone of the object and add the clone instead. This is necessary + // because there is only one set of previous and next pointers contained + // in each element. Therefore, the same Java object cannot possibly appear + // more than once in the list. + newElement = newElement.duplicate(); + } else if (element == null) { + // When we hit a null, we know that we have seen all the possible values + // that might be possible to replace with the new element we are adding. + // This is because of the denseness invariant: if an element E should be + // in slot S but ends up in slot T, instead, there will never be a null + // between S and T. + if (bestSlot == INVALID_INDEX) { + bestSlot = slot; + } + break; + } else if (comparator.compare((E) newElement, (E) element)) { + // If the current element can be replaced by the new element we are adding, + // mark it down as the best slot we've found so far. We don't do the + // replacement immediately because we want to see all the possible values + // to make sure that the element we are adding has not already been added. + // That requires iterating until we hit a null. 
+ bestSlot = slot; } slot = (slot + 1) % addElements.length; } - throw new RuntimeException("Not enough hash table slots to add a new element."); + boolean growing = true; + if (addElements[bestSlot] != null) { + removeFromList(head, addElements, bestSlot); + growing = false; + } + addElements[bestSlot] = newElement; + addToListTail(head, addElements, bestSlot); + return growing; } private void changeCapacity(int newCapacity) { @@ -429,8 +563,7 @@ private void changeCapacity(int newCapacity) { for (Iterator iter = iterator(); iter.hasNext(); ) { Element element = iter.next(); iter.remove(); - int newSlot = addInternal(element, newElements); - addToListTail(newHead, newElements, newSlot); + addInternal(newHead, element, newElements, ReferenceEqualityComparator.INSTANCE); } this.elements = newElements; this.head = newHead; @@ -438,15 +571,25 @@ private void changeCapacity(int newCapacity) { } /** - * Remove the first element e such that key.equals(e) - * and key.hashCode == e.hashCode. + * Remove an element from the collection, using object equality semantics. * - * @param key The object to try to match. - * @return True if an element was removed; false otherwise. + * @param target The object to try to remove. + * @return True if an element was removed; false otherwise. */ @Override - final public boolean remove(Object key) { - int slot = findElementToRemove(key); + final public boolean remove(Object target) { + return remove(target, ObjectEqualityComparator.INSTANCE); + } + + /** + * Remove an element from the collection, using the given comparator to test equality. + * + * @param target The object to try to remove. + * @param comparator The comparator to use. + * @return True if an element was removed; false otherwise. 
+ */ + final public boolean remove(Object target, Comparator comparator) { + int slot = findIndex(target, comparator); if (slot == INVALID_INDEX) { return false; } @@ -454,10 +597,6 @@ final public boolean remove(Object key) { return true; } - int findElementToRemove(Object key) { - return findIndexOfEqualElement(key); - } - /** * Remove an element in a particular slot. * @@ -539,7 +678,7 @@ public ImplicitLinkedHashCollection(int expectedNumElements) { public ImplicitLinkedHashCollection(Iterator iter) { clear(0); while (iter.hasNext()) { - mustAdd(iter.next()); + add(iter.next()); } } @@ -576,14 +715,6 @@ final public void clear(int expectedNumElements) { * {@code ImplicitLinkedHashCollectionListIterator} iterates over the elements * in insertion order, it is sufficient to call {@code valuesList.equals}. * - * Note that {@link ImplicitLinkedHashMultiCollection} does not override - * {@code equals} and uses this method as well. This means that two - * {@code ImplicitLinkedHashMultiCollection} objects will be considered equal even - * if they each contain two elements A and B such that A.equals(B) but A != B and - * A and B have switched insertion positions between the two collections. This - * is an acceptable definition of equality, because the collections are still - * equal in terms of the order and value of each element. - * * @param o object to be compared for equality with this collection * @return true is the specified object is equal to this collection */ diff --git a/clients/src/main/java/org/apache/kafka/common/utils/ImplicitLinkedHashMultiCollection.java b/clients/src/main/java/org/apache/kafka/common/utils/ImplicitLinkedHashMultiCollection.java deleted file mode 100644 index b5ae8f9793a46..0000000000000 --- a/clients/src/main/java/org/apache/kafka/common/utils/ImplicitLinkedHashMultiCollection.java +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.kafka.common.utils; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.Iterator; -import java.util.List; - -/** - * A memory-efficient hash multiset which tracks the order of insertion of elements. - * See org.apache.kafka.common.utils.ImplicitLinkedHashCollection for implementation details. - * - * This class is a multi-set because it allows multiple elements to be inserted that are - * equal to each other. - * - * We use reference equality when adding elements to the set. A new element A can - * be added if there is no existing element B such that A == B. If an element B - * exists such that A.equals(B), A will still be added. - * - * When deleting an element A from the set, we will try to delete the element B such - * that A == B. If no such element can be found, we will try to delete an element B - * such that A.equals(B). - * - * contains() and find() are unchanged from the base class-- they will look for element - * based on object equality, not reference equality. - * - * This multiset does not allow null elements. It does not have internal synchronization. 
- */ -public class ImplicitLinkedHashMultiCollection - extends ImplicitLinkedHashCollection { - public ImplicitLinkedHashMultiCollection() { - super(0); - } - - public ImplicitLinkedHashMultiCollection(int expectedNumElements) { - super(expectedNumElements); - } - - public ImplicitLinkedHashMultiCollection(Iterator iter) { - super(iter); - } - - - /** - * Adds a new element to the appropriate place in the elements array. - * - * @param newElement The new element to add. - * @param addElements The elements array. - * @return The index at which the element was inserted, or INVALID_INDEX - * if the element could not be inserted. - */ - @Override - int addInternal(Element newElement, Element[] addElements) { - int slot = slot(addElements, newElement); - for (int seen = 0; seen < addElements.length; seen++) { - Element element = addElements[slot]; - if (element == null) { - addElements[slot] = newElement; - return slot; - } - if (element == newElement) { - return INVALID_INDEX; - } - slot = (slot + 1) % addElements.length; - } - throw new RuntimeException("Not enough hash table slots to add a new element."); - } - - /** - * Find an element matching an example element. - * - * @param key The element to match. - * - * @return The match index, or INVALID_INDEX if no match was found. - */ - @Override - int findElementToRemove(Object key) { - if (key == null || size() == 0) { - return INVALID_INDEX; - } - int slot = slot(elements, key); - int bestSlot = INVALID_INDEX; - for (int seen = 0; seen < elements.length; seen++) { - Element element = elements[slot]; - if (element == null) { - return bestSlot; - } - if (key == element) { - return slot; - } else if (key.equals(element)) { - bestSlot = slot; - } - slot = (slot + 1) % elements.length; - } - return INVALID_INDEX; - } - - /** - * Returns all of the elements e in the collection such that - * key.equals(e) and key.hashCode() == e.hashCode(). - * - * @param key The element to match. 
- * - * @return All of the matching elements. - */ - final public List findAll(E key) { - if (key == null || size() == 0) { - return Collections.emptyList(); - } - ArrayList results = new ArrayList<>(); - int slot = slot(elements, key); - for (int seen = 0; seen < elements.length; seen++) { - Element element = elements[slot]; - if (element == null) { - break; - } - if (key.equals(element)) { - @SuppressWarnings("unchecked") - E result = (E) elements[slot]; - results.add(result); - } - slot = (slot + 1) % elements.length; - } - return results; - } -} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/ImplicitLinkedHashCollectionTest.java b/clients/src/test/java/org/apache/kafka/common/utils/ImplicitLinkedHashCollectionTest.java index 389c24e456b96..6534b53f43a50 100644 --- a/clients/src/test/java/org/apache/kafka/common/utils/ImplicitLinkedHashCollectionTest.java +++ b/clients/src/test/java/org/apache/kafka/common/utils/ImplicitLinkedHashCollectionTest.java @@ -16,21 +16,24 @@ */ package org.apache.kafka.common.utils; +import org.apache.kafka.common.utils.ImplicitLinkedHashCollection.Element; +import org.apache.kafka.common.utils.ImplicitLinkedHashCollection.ObjectEqualityComparator; +import org.apache.kafka.common.utils.ImplicitLinkedHashCollection.ReferenceEqualityComparator; import org.junit.Assert; import org.junit.Rule; import org.junit.Test; import org.junit.rules.Timeout; import java.util.ArrayList; -import java.util.Collection; +import java.util.Arrays; import java.util.Iterator; -import java.util.LinkedHashSet; +import java.util.LinkedList; import java.util.List; import java.util.ListIterator; import java.util.Random; import java.util.Set; -import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertFalse; @@ -44,13 +47,20 @@ public class ImplicitLinkedHashCollectionTest { @Rule final 
public Timeout globalTimeout = Timeout.millis(120000); - final static class TestElement implements ImplicitLinkedHashCollection.Element { + final static class TestElement implements Element { private int prev = ImplicitLinkedHashCollection.INVALID_INDEX; private int next = ImplicitLinkedHashCollection.INVALID_INDEX; private final int val; + private final boolean wasDuplicated; TestElement(int val) { this.val = val; + this.wasDuplicated = false; + } + + private TestElement(int val, boolean wasDuplicated) { + this.val = val; + this.wasDuplicated = wasDuplicated; } @Override @@ -73,6 +83,11 @@ public void setNext(int next) { this.next = next; } + @Override + public Element duplicate() { + return new TestElement(val, true); + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ -94,8 +109,8 @@ public int hashCode() { @Test public void testNullForbidden() { - ImplicitLinkedHashMultiCollection multiColl = new ImplicitLinkedHashMultiCollection<>(); - assertFalse(multiColl.add(null)); + ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(); + assertFalse(coll.add(null)); } @Test @@ -105,7 +120,7 @@ public void testInsertDelete() { TestElement second = new TestElement(2); assertTrue(coll.add(second)); assertTrue(coll.add(new TestElement(3))); - assertFalse(coll.add(new TestElement(3))); + assertTrue(coll.addOrReplace(new TestElement(3), ObjectEqualityComparator.INSTANCE)); assertEquals(3, coll.size()); assertTrue(coll.contains(new TestElement(1))); assertFalse(coll.contains(new TestElement(4))); @@ -118,6 +133,28 @@ public void testInsertDelete() { assertEquals(0, coll.size()); } + @Test + public void testInsertDeleteDuplicates() { + ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(100); + TestElement e1 = new TestElement(1); + TestElement e2 = new TestElement(1); + TestElement e3 = new TestElement(2); + assertTrue(coll.add(e1)); + assertTrue(coll.add(e2)); + assertTrue(coll.add(e3)); + 
assertTrue(coll.add(e3)); + assertEquals(4, coll.size()); + List ones = coll.findAll(e1); + assertEquals(Arrays.asList(e1, e1), ones); + assertTrue(e1 == coll.find(e1, ReferenceEqualityComparator.INSTANCE)); + assertTrue(e2 == coll.find(e2, ReferenceEqualityComparator.INSTANCE)); + List threes = coll.findAll(e3); + assertEquals(2, threes.size()); + assertEquals(1, threes.stream().filter(e -> e == e3).count()); + assertEquals(2, threes.stream().filter(e -> e.equals(e3)).count()); + coll.remove(e2); + } + static void expectTraversal(Iterator iterator, Integer... sequence) { int i = 0; while (iterator.hasNext()) { @@ -132,19 +169,19 @@ static void expectTraversal(Iterator iterator, Integer... sequence) sequence.length + " were expected.", i == sequence.length); } - static void expectTraversal(Iterator iter, Iterator expectedIter) { + static void expectTraversal(Iterator iter, Iterator expectedIter) { int i = 0; while (iter.hasNext()) { TestElement element = iter.next(); Assert.assertTrue("Iterator yieled " + (i + 1) + " elements, but only " + - i + " were expected.", expectedIter.hasNext()); - Integer expected = expectedIter.next(); - Assert.assertEquals("Iterator value number " + (i + 1) + " was incorrect.", - expected.intValue(), element.val); + i + " were expected.", expectedIter.hasNext()); + TestElement expected = expectedIter.next(); + assertTrue("Iterator value number " + (i + 1) + " was incorrect.", + expected == element); i = i + 1; } Assert.assertFalse("Iterator yieled " + i + " elements, but at least " + - (i + 1) + " were expected.", expectedIter.hasNext()); + (i + 1) + " were expected.", expectedIter.hasNext()); } @Test @@ -178,6 +215,49 @@ public void testTraversal() { assertTrue(coll.isEmpty()); } + @Test + public void testRetainAll() { + ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(); + assertTrue(coll.addAll( + Arrays.asList(new TestElement(2), new TestElement(3), new TestElement(1)))); + assertTrue(coll.addAll( + 
Arrays.asList(new TestElement(3), new TestElement(2), new TestElement(1)))); + assertFalse(coll.retainAll( + Arrays.asList(new TestElement(1), new TestElement(2), new TestElement(3)))); + assertEquals(6, coll.size()); + assertTrue(coll.containsAll( + Arrays.asList(new TestElement(1), new TestElement(3)))); + assertFalse(coll.containsAll( + Arrays.asList(new TestElement(1), new TestElement(3), new TestElement(4)))); + assertTrue(coll.containsAll( + Arrays.asList(new TestElement(1), new TestElement(2), new TestElement(3)))); + assertTrue(coll.retainAll( + Arrays.asList(new TestElement(1), new TestElement(3)))); + assertEquals(4, coll.size()); + assertTrue(coll.containsAll( + Arrays.asList(new TestElement(1), new TestElement(3)))); + assertFalse(coll.containsAll( + Arrays.asList(new TestElement(1), new TestElement(2), new TestElement(3)))); + } + + @SuppressWarnings("unchecked") + @Test + public void testToArray() { + TestElement[] elements = new TestElement[] { + new TestElement(1), + new TestElement(5), + new TestElement(2) + }; + ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(); + assertTrue(coll.addAll(Arrays.asList(elements))); + Object[] elements2 = coll.toArray(); + assertArrayEquals(elements, elements2); + TestElement[] elements3 = new TestElement[2]; + assertArrayEquals(elements, coll.toArray(elements3)); + TestElement[] elements4 = new TestElement[3]; + assertArrayEquals(elements, coll.toArray(elements4)); + } + @Test public void testSetViewGet() { ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(); @@ -481,13 +561,23 @@ public void testEnlargement() { @Test public void testManyInsertsAndDeletes() { Random random = new Random(123); - LinkedHashSet existing = new LinkedHashSet<>(); + LinkedList existing = new LinkedList<>(); ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(); for (int i = 0; i < 100; i++) { - addRandomElement(random, existing, coll); - addRandomElement(random, existing, 
coll); - addRandomElement(random, existing, coll); - removeRandomElement(random, existing, coll); + for (int j = 0; j < 4; j++) { + TestElement testElement = new TestElement(random.nextInt()); + coll.add(testElement); + existing.add(testElement); + } + int elementToRemove = random.nextInt(coll.size()); + Iterator iter1 = coll.iterator(); + Iterator iter2 = existing.iterator(); + for (int j = 0; j <= elementToRemove; j++) { + iter1.next(); + iter2.next(); + } + iter1.remove(); + iter2.remove(); expectTraversal(coll.iterator(), existing.iterator()); } } @@ -509,9 +599,12 @@ public void testEquals() { coll3.add(new TestElement(3)); coll3.add(new TestElement(2)); - assertEquals(coll1, coll2); - assertNotEquals(coll1, coll3); - assertNotEquals(coll2, coll3); + assertTrue(coll1.toString() + " was not equal to " + coll2.toString(), + coll1.equals(coll2)); + assertFalse(coll1.toString() + " was equal to " + coll3.toString(), + coll1.equals(coll3)); + assertFalse(coll2.toString() + " was equal to " + coll3.toString(), + coll2.equals(coll3)); } @Test @@ -522,25 +615,28 @@ public void testFindContainsRemoveOnEmptyCollection() { assertFalse(coll.remove(new TestElement(2))); } - private void addRandomElement(Random random, LinkedHashSet existing, - ImplicitLinkedHashCollection set) { - int next; - do { - next = random.nextInt(); - } while (existing.contains(next)); - existing.add(next); - set.add(new TestElement(next)); + @Test + public void testFindFindAllContainsRemoveOnEmptyCollection() { + ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(); + assertNull(coll.find(new TestElement(2))); + assertFalse(coll.contains(new TestElement(2))); + assertFalse(coll.remove(new TestElement(2))); + assertTrue(coll.findAll(new TestElement(2)).isEmpty()); } - @SuppressWarnings("unlikely-arg-type") - private void removeRandomElement(Random random, Collection existing, - ImplicitLinkedHashCollection coll) { - int removeIdx = random.nextInt(existing.size()); - Iterator 
iter = existing.iterator(); - Integer element = null; - for (int i = 0; i <= removeIdx; i++) { - element = iter.next(); - } - existing.remove(new TestElement(element)); + @Test + public void testReinsertExistingElement() { + ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(); + TestElement testElement = new TestElement(1); + assertTrue(coll.add(testElement)); + assertTrue(coll.add(testElement)); + Iterator iter = coll.iterator(); + assertTrue(iter.hasNext()); + assertEquals(testElement, iter.next()); + assertFalse(testElement.wasDuplicated); + assertTrue(iter.hasNext()); + TestElement clone = iter.next(); + assertTrue(clone.wasDuplicated); + assertEquals(1, clone.val); } } diff --git a/clients/src/test/java/org/apache/kafka/common/utils/ImplicitLinkedHashMultiCollectionTest.java b/clients/src/test/java/org/apache/kafka/common/utils/ImplicitLinkedHashMultiCollectionTest.java deleted file mode 100644 index ad87b55c493c4..0000000000000 --- a/clients/src/test/java/org/apache/kafka/common/utils/ImplicitLinkedHashMultiCollectionTest.java +++ /dev/null @@ -1,173 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.kafka.common.utils; - -import org.apache.kafka.common.utils.ImplicitLinkedHashCollectionTest.TestElement; -import org.junit.Assert; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.Timeout; - -import java.util.Iterator; -import java.util.LinkedList; -import java.util.Random; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; - -/** - * A unit test for ImplicitLinkedHashMultiCollection. - */ -public class ImplicitLinkedHashMultiCollectionTest { - @Rule - final public Timeout globalTimeout = Timeout.millis(120000); - - @Test - public void testNullForbidden() { - ImplicitLinkedHashMultiCollection multiSet = new ImplicitLinkedHashMultiCollection<>(); - assertFalse(multiSet.add(null)); - } - - @Test - public void testFindFindAllContainsRemoveOnEmptyCollection() { - ImplicitLinkedHashMultiCollection coll = new ImplicitLinkedHashMultiCollection<>(); - assertNull(coll.find(new TestElement(2))); - assertFalse(coll.contains(new TestElement(2))); - assertFalse(coll.remove(new TestElement(2))); - assertTrue(coll.findAll(new TestElement(2)).isEmpty()); - } - - @Test - public void testInsertDelete() { - ImplicitLinkedHashMultiCollection multiSet = new ImplicitLinkedHashMultiCollection<>(100); - TestElement e1 = new TestElement(1); - TestElement e2 = new TestElement(1); - TestElement e3 = new TestElement(2); - multiSet.mustAdd(e1); - multiSet.mustAdd(e2); - multiSet.mustAdd(e3); - assertFalse(multiSet.add(e3)); - assertEquals(3, multiSet.size()); - expectExactTraversal(multiSet.findAll(e1).iterator(), e1, e2); - expectExactTraversal(multiSet.findAll(e3).iterator(), e3); - multiSet.remove(e2); - expectExactTraversal(multiSet.findAll(e1).iterator(), e1); - assertTrue(multiSet.contains(e2)); - } - - @Test - public void testTraversal() { - 
ImplicitLinkedHashMultiCollection multiSet = new ImplicitLinkedHashMultiCollection<>(); - expectExactTraversal(multiSet.iterator()); - TestElement e1 = new TestElement(1); - TestElement e2 = new TestElement(1); - TestElement e3 = new TestElement(2); - assertTrue(multiSet.add(e1)); - assertTrue(multiSet.add(e2)); - assertTrue(multiSet.add(e3)); - expectExactTraversal(multiSet.iterator(), e1, e2, e3); - assertTrue(multiSet.remove(e2)); - expectExactTraversal(multiSet.iterator(), e1, e3); - assertTrue(multiSet.remove(e1)); - expectExactTraversal(multiSet.iterator(), e3); - } - - static void expectExactTraversal(Iterator iterator, TestElement... sequence) { - int i = 0; - while (iterator.hasNext()) { - TestElement element = iterator.next(); - assertTrue("Iterator yieled " + (i + 1) + " elements, but only " + - sequence.length + " were expected.", i < sequence.length); - if (sequence[i] != element) { - fail("Iterator value number " + (i + 1) + " was incorrect."); - } - i = i + 1; - } - assertTrue("Iterator yieled " + (i + 1) + " elements, but " + - sequence.length + " were expected.", i == sequence.length); - } - - @Test - public void testEnlargement() { - ImplicitLinkedHashMultiCollection multiSet = new ImplicitLinkedHashMultiCollection<>(5); - assertEquals(11, multiSet.numSlots()); - TestElement[] testElements = { - new TestElement(100), - new TestElement(101), - new TestElement(102), - new TestElement(100), - new TestElement(101), - new TestElement(105) - }; - for (int i = 0; i < testElements.length; i++) { - assertTrue(multiSet.add(testElements[i])); - } - for (int i = 0; i < testElements.length; i++) { - assertFalse(multiSet.add(testElements[i])); - } - assertEquals(23, multiSet.numSlots()); - assertEquals(testElements.length, multiSet.size()); - expectExactTraversal(multiSet.iterator(), testElements); - multiSet.remove(testElements[1]); - assertEquals(23, multiSet.numSlots()); - assertEquals(5, multiSet.size()); - expectExactTraversal(multiSet.iterator(), - 
testElements[0], testElements[2], testElements[3], testElements[4], testElements[5]); - } - - @Test - public void testManyInsertsAndDeletes() { - Random random = new Random(123); - LinkedList existing = new LinkedList<>(); - ImplicitLinkedHashMultiCollection multiSet = new ImplicitLinkedHashMultiCollection<>(); - for (int i = 0; i < 100; i++) { - for (int j = 0; j < 4; j++) { - TestElement testElement = new TestElement(random.nextInt()); - multiSet.mustAdd(testElement); - existing.add(testElement); - } - int elementToRemove = random.nextInt(multiSet.size()); - Iterator iter1 = multiSet.iterator(); - Iterator iter2 = existing.iterator(); - for (int j = 0; j <= elementToRemove; j++) { - iter1.next(); - iter2.next(); - } - iter1.remove(); - iter2.remove(); - expectTraversal(multiSet.iterator(), existing.iterator()); - } - } - - void expectTraversal(Iterator iter, Iterator expectedIter) { - int i = 0; - while (iter.hasNext()) { - TestElement element = iter.next(); - Assert.assertTrue("Iterator yieled " + (i + 1) + " elements, but only " + - i + " were expected.", expectedIter.hasNext()); - TestElement expected = expectedIter.next(); - assertTrue("Iterator value number " + (i + 1) + " was incorrect.", - expected == element); - i = i + 1; - } - Assert.assertFalse("Iterator yieled " + i + " elements, but at least " + - (i + 1) + " were expected.", expectedIter.hasNext()); - } -} diff --git a/core/src/main/scala/kafka/server/FetchSession.scala b/core/src/main/scala/kafka/server/FetchSession.scala index 8a3cd96fb2f81..eaa2214e07da3 100644 --- a/core/src/main/scala/kafka/server/FetchSession.scala +++ b/core/src/main/scala/kafka/server/FetchSession.scala @@ -29,6 +29,7 @@ import org.apache.kafka.common.protocol.Errors import org.apache.kafka.common.record.Records import org.apache.kafka.common.requests.FetchMetadata.{FINAL_EPOCH, INITIAL_EPOCH, INVALID_SESSION_ID} import org.apache.kafka.common.requests.{FetchRequest, FetchResponse, FetchMetadata => JFetchMetadata} +import 
org.apache.kafka.common.utils.ImplicitLinkedHashCollection.ObjectEqualityComparator import org.apache.kafka.common.utils.{ImplicitLinkedHashCollection, Time, Utils} import scala.math.Ordered.orderingToOrdered @@ -181,6 +182,11 @@ class CachedPartition(val topic: String, ", localLogStartOffset=" + localLogStartOffset + ")" } + + // This operation is only used by ImplicitLinkedHashCollection when we + // try to insert the same object into the collection more than once. + // Since we do not plan to do that, we don't need to implement this. + override def duplicate(): CachedPartition = throw new UnsupportedOperationException } /** @@ -245,7 +251,7 @@ case class FetchSession(val id: Int, val newCachedPart = new CachedPartition(topicPart, reqData) val cachedPart = partitionMap.find(newCachedPart) if (cachedPart == null) { - partitionMap.mustAdd(newCachedPart) + partitionMap.add(newCachedPart) added.add(topicPart) } else { cachedPart.updateRequestParams(reqData) @@ -378,7 +384,8 @@ class FullFetchContext(private val time: Time, val part = entry.getKey val respData = entry.getValue val reqData = fetchData.get(part) - cachedPartitions.mustAdd(new CachedPartition(part, reqData, respData)) + cachedPartitions.addOrReplace(new CachedPartition(part, reqData, respData), + ObjectEqualityComparator.INSTANCE) }) cachedPartitions } @@ -431,7 +438,7 @@ class IncrementalFetchContext(private val time: Time, nextElement = element if (updateFetchContextAndRemoveUnselected) { session.partitionMap.remove(cachedPart) - session.partitionMap.mustAdd(cachedPart) + session.partitionMap.add(cachedPart) } } else { if (updateFetchContextAndRemoveUnselected) { diff --git a/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java b/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java index 263da1231def0..3c2b0d608221b 100644 --- a/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java +++ 
b/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java @@ -100,7 +100,7 @@ private void generateClass(Optional topLevelMessageSpec, buffer.printf("%n"); generateClassSize(className, struct, parentVersions); buffer.printf("%n"); - generateClassEquals(className, struct, isSetElement); + generateClassEquals(className, struct, false); buffer.printf("%n"); generateClassHashCode(struct, isSetElement); buffer.printf("%n"); @@ -109,11 +109,17 @@ private void generateClass(Optional topLevelMessageSpec, buffer.printf("%n"); generateUnknownTaggedFieldsAccessor(struct); generateFieldMutators(struct, className, isSetElement); + generateClassCopy(struct, className); + generateClassDuplicate(className); if (!isTopLevel) { buffer.decrementIndent(); buffer.printf("}%n"); } + if (isSetElement) { + buffer.printf("%n"); + generateClassComparator(className, struct); + } generateSubclasses(className, struct, parentVersions, isSetElement); if (isTopLevel) { for (Iterator iter = structRegistry.commonStructs(); iter.hasNext(); ) { @@ -139,8 +145,8 @@ private void generateClassHeader(String className, boolean isTopLevel, headerGenerator.addImport(MessageGenerator.MESSAGE_CLASS); } if (isSetElement) { - headerGenerator.addImport(MessageGenerator.IMPLICIT_LINKED_HASH_MULTI_COLLECTION_CLASS); - implementedInterfaces.add("ImplicitLinkedHashMultiCollection.Element"); + headerGenerator.addImport(MessageGenerator.IMPLICIT_LINKED_HASH_COLLECTION_CLASS); + implementedInterfaces.add("ImplicitLinkedHashCollection.Element"); } Set classModifiers = new HashSet<>(); classModifiers.add("public"); @@ -173,8 +179,8 @@ private void generateSubclasses(String className, StructSpec struct, private void generateHashSet(String className, StructSpec struct) { buffer.printf("%n"); - headerGenerator.addImport(MessageGenerator.IMPLICIT_LINKED_HASH_MULTI_COLLECTION_CLASS); - buffer.printf("public static class %s extends ImplicitLinkedHashMultiCollection<%s> {%n", + 
headerGenerator.addImport(MessageGenerator.IMPLICIT_LINKED_HASH_COLLECTION_CLASS); + buffer.printf("public static class %s extends ImplicitLinkedHashCollection<%s> {%n", collectionType(className), className); buffer.incrementIndent(); generateHashSetZeroArgConstructor(className); @@ -220,7 +226,7 @@ private void generateHashSetFindMethod(String className, StructSpec struct) { commaSeparatedHashSetFieldAndTypes(struct)); buffer.incrementIndent(); generateKeyElement(className, struct); - headerGenerator.addImport(MessageGenerator.IMPLICIT_LINKED_HASH_MULTI_COLLECTION_CLASS); + headerGenerator.addImport(MessageGenerator.IMPLICIT_LINKED_HASH_COLLECTION_CLASS); buffer.printf("return find(_key);%n"); buffer.decrementIndent(); buffer.printf("}%n"); @@ -233,8 +239,8 @@ private void generateHashSetFindAllMethod(String className, StructSpec struct) { commaSeparatedHashSetFieldAndTypes(struct)); buffer.incrementIndent(); generateKeyElement(className, struct); - headerGenerator.addImport(MessageGenerator.IMPLICIT_LINKED_HASH_MULTI_COLLECTION_CLASS); - buffer.printf("return findAll(_key);%n"); + headerGenerator.addImport(MessageGenerator.IMPLICIT_LINKED_HASH_COLLECTION_CLASS); + buffer.printf("return findAll(_key, %sComparator.INSTANCE);%n", className); buffer.decrementIndent(); buffer.printf("}%n"); buffer.printf("%n"); @@ -327,6 +333,97 @@ private void generateFieldMutators(StructSpec struct, String className, } } + private void generateClassCopy(StructSpec struct, String className) { + buffer.printf("%n"); + buffer.printf("public void copy(%s other) {%n", className); + buffer.incrementIndent(); + for (FieldSpec field : struct.fields()) { + generateClassFieldCopy(field); + } + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateClassFieldCopy(FieldSpec field) { + if (field.type() instanceof FieldType.ArrayType) { + FieldType.ArrayType arrayType = (FieldType.ArrayType) field.type(); + FieldType elementType = arrayType.elementType(); + 
headerGenerator.addImport(MessageGenerator.ITERATOR_CLASS); + buffer.printf("for (Iterator<%s> _i = other.%s.iterator(); _i.hasNext(); ) {%n", + getBoxedJavaType(elementType), + field.camelCaseName()); + buffer.incrementIndent(); + buffer.printf("this.%s.add(%s);%n", + field.camelCaseName(), + generateCopyExpression(elementType, + field.zeroCopy(), + !field.nullableVersions().empty(), + "_i.next()")); + buffer.decrementIndent(); + buffer.printf("}%n"); + } else { + buffer.printf("this.%s = %s;%n", + field.camelCaseName(), + generateCopyExpression(field.type(), + field.zeroCopy(), + !field.nullableVersions().empty(), + String.format("other.%s", field.camelCaseName()))); + } + } + + private String generateCopyExpression(FieldType type, + boolean zeroCopy, + boolean nullable, + String name) { + if (type instanceof FieldType.BoolFieldType || + type instanceof FieldType.Int8FieldType || + type instanceof FieldType.Int16FieldType || + type instanceof FieldType.Int32FieldType || + type instanceof FieldType.Int64FieldType) { + // Primitives should be copied by value. + return name; + } else if (type instanceof FieldType.UUIDFieldType || + type instanceof FieldType.StringFieldType) { + // Immutable types can be copied by reference. + return name; + } else if (type.isBytes()) { + if (zeroCopy) { + if (nullable) { + return String.format("(%s == null) ? null : %s.duplicate()", + name, name); + } else { + return String.format("%s.duplicate()", + name); + } + } else { + headerGenerator.addImport(MessageGenerator.ARRAYS_CLASS); + if (nullable) { + return String.format("(%s == null) ? 
null : " + + "Arrays.copyOf(%s, %s.length)", + name, name, name); + } else { + return String.format("Arrays.copyOf(%s, %s.length)", + name, name); + } + } + } else if (type.isStruct()) { + return String.format("%s.duplicate()", name); + } else { + throw new RuntimeException("Unsupported field type " + type); + } + } + + private void generateClassDuplicate(String className) { + buffer.printf("%n"); + buffer.printf("public %s duplicate() {%n", className); + buffer.incrementIndent(); + buffer.printf("%s _duplicate = new %s();%n", className, className); + buffer.printf("_duplicate.copy(this);%n"); + buffer.printf("return _duplicate;%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + private static String collectionType(String baseType) { return baseType + "Collection"; } @@ -359,7 +456,7 @@ private String fieldAbstractJavaType(FieldSpec field) { } else if (field.type().isArray()) { FieldType.ArrayType arrayType = (FieldType.ArrayType) field.type(); if (structRegistry.isStructArrayWithKeys(field)) { - headerGenerator.addImport(MessageGenerator.IMPLICIT_LINKED_HASH_MULTI_COLLECTION_CLASS); + headerGenerator.addImport(MessageGenerator.IMPLICIT_LINKED_HASH_COLLECTION_CLASS); return collectionType(arrayType.elementType().toString()); } else { headerGenerator.addImport(MessageGenerator.LIST_CLASS); @@ -374,7 +471,7 @@ private String fieldConcreteJavaType(FieldSpec field) { if (field.type().isArray()) { FieldType.ArrayType arrayType = (FieldType.ArrayType) field.type(); if (structRegistry.isStructArrayWithKeys(field)) { - headerGenerator.addImport(MessageGenerator.IMPLICIT_LINKED_HASH_MULTI_COLLECTION_CLASS); + headerGenerator.addImport(MessageGenerator.IMPLICIT_LINKED_HASH_COLLECTION_CLASS); return collectionType(arrayType.elementType().toString()); } else { headerGenerator.addImport(MessageGenerator.ARRAYLIST_CLASS); @@ -625,7 +722,7 @@ private void generateVariableLengthReader(Versions fieldFlexibleVersions, } else if (type.isArray()) { FieldType.ArrayType 
arrayType = (FieldType.ArrayType) type; if (isStructArrayWithKeys) { - headerGenerator.addImport(MessageGenerator.IMPLICIT_LINKED_HASH_MULTI_COLLECTION_CLASS); + headerGenerator.addImport(MessageGenerator.IMPLICIT_LINKED_HASH_COLLECTION_CLASS); buffer.printf("%s newCollection = new %s(%s);%n", collectionType(arrayType.elementType().toString()), collectionType(arrayType.elementType().toString()), lengthVar); @@ -1631,7 +1728,7 @@ private void generateStringToBytes(String name) { buffer.printf("_cache.cacheSerializedValue(%s, _stringBytes);%n", name); } - private void generateClassEquals(String className, StructSpec struct, boolean onlyMapKeys) { + private void generateClassEquals(String className, StructSpec struct, boolean isKeyEquals) { buffer.printf("@Override%n"); buffer.printf("public boolean equals(Object obj) {%n"); buffer.incrementIndent(); @@ -1639,55 +1736,79 @@ private void generateClassEquals(String className, StructSpec struct, boolean on if (!struct.fields().isEmpty()) { buffer.printf("%s other = (%s) obj;%n", className, className); for (FieldSpec field : struct.fields()) { - if ((!onlyMapKeys) || field.mapKey()) { - generateFieldEquals(field); + generateFieldEquals(field, "this", "other"); + } + } + buffer.printf("return true;%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateClassComparator(String className, StructSpec struct) { + headerGenerator.addStaticImport(MessageGenerator.IMPLICIT_LINKED_HASH_COLLECTION_COMPARATOR_CLASS); + buffer.printf("public static class %sComparator implements Comparator<%s> {%n", + className, className); + buffer.incrementIndent(); + buffer.printf("public static final %sComparator INSTANCE = new %sComparator();%n", + className, className); + buffer.printf("%n"); + buffer.printf("@Override%n"); + buffer.printf("public boolean compare(%s x, %s y) {%n", + className, className); + buffer.incrementIndent(); + if (!struct.fields().isEmpty()) { + for (FieldSpec field : struct.fields()) { + if 
(field.mapKey()) { + generateFieldEquals(field, "x", "y"); } } } buffer.printf("return true;%n"); buffer.decrementIndent(); buffer.printf("}%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); } - private void generateFieldEquals(FieldSpec field) { + private void generateFieldEquals(FieldSpec field, String left, String right) { if (field.type() instanceof FieldType.UUIDFieldType) { - buffer.printf("if (!this.%s.equals(other.%s)) return false;%n", - field.camelCaseName(), field.camelCaseName()); + buffer.printf("if (!%s.%s.equals(%s.%s)) return false;%n", + left, field.camelCaseName(), right, field.camelCaseName()); } else if (field.type().isString() || field.type().isArray() || field.type().isStruct()) { - buffer.printf("if (this.%s == null) {%n", field.camelCaseName()); + buffer.printf("if (%s.%s == null) {%n", left, field.camelCaseName()); buffer.incrementIndent(); - buffer.printf("if (other.%s != null) return false;%n", field.camelCaseName()); + buffer.printf("if (%s.%s != null) return false;%n", right, field.camelCaseName()); buffer.decrementIndent(); buffer.printf("} else {%n"); buffer.incrementIndent(); - buffer.printf("if (!this.%s.equals(other.%s)) return false;%n", - field.camelCaseName(), field.camelCaseName()); + buffer.printf("if (!%s.%s.equals(%s.%s)) return false;%n", + left, field.camelCaseName(), right, field.camelCaseName()); buffer.decrementIndent(); buffer.printf("}%n"); } else if (field.type().isBytes()) { if (field.zeroCopy()) { headerGenerator.addImport(MessageGenerator.OBJECTS_CLASS); - buffer.printf("if (!Objects.equals(this.%s, other.%s)) return false;%n", - field.camelCaseName(), field.camelCaseName()); + buffer.printf("if (!Objects.equals(%s.%s, %s.%s)) return false;%n", + left, field.camelCaseName(), right, field.camelCaseName()); } else { // Arrays#equals handles nulls. 
headerGenerator.addImport(MessageGenerator.ARRAYS_CLASS); - buffer.printf("if (!Arrays.equals(this.%s, other.%s)) return false;%n", - field.camelCaseName(), field.camelCaseName()); + buffer.printf("if (!Arrays.equals(%s.%s, %s.%s)) return false;%n", + left, field.camelCaseName(), right, field.camelCaseName()); } } else { - buffer.printf("if (%s != other.%s) return false;%n", - field.camelCaseName(), field.camelCaseName()); + buffer.printf("if (%s.%s != %s.%s) return false;%n", + left, field.camelCaseName(), right, field.camelCaseName()); } } - private void generateClassHashCode(StructSpec struct, boolean onlyMapKeys) { + private void generateClassHashCode(StructSpec struct, boolean isKeyHash) { buffer.printf("@Override%n"); buffer.printf("public int hashCode() {%n"); buffer.incrementIndent(); buffer.printf("int hashCode = 0;%n"); for (FieldSpec field : struct.fields()) { - if ((!onlyMapKeys) || field.mapKey()) { + if ((!isKeyHash) || field.mapKey()) { generateFieldHashCode(field); } } diff --git a/generator/src/main/java/org/apache/kafka/message/MessageGenerator.java b/generator/src/main/java/org/apache/kafka/message/MessageGenerator.java index ddb3250533eaa..526d77fc4c83d 100644 --- a/generator/src/main/java/org/apache/kafka/message/MessageGenerator.java +++ b/generator/src/main/java/org/apache/kafka/message/MessageGenerator.java @@ -61,8 +61,11 @@ public final class MessageGenerator { static final String ARRAYLIST_CLASS = "java.util.ArrayList"; - static final String IMPLICIT_LINKED_HASH_MULTI_COLLECTION_CLASS = - "org.apache.kafka.common.utils.ImplicitLinkedHashMultiCollection"; + static final String IMPLICIT_LINKED_HASH_COLLECTION_COMPARATOR_CLASS = + "org.apache.kafka.common.utils.ImplicitLinkedHashCollection.Comparator"; + + static final String IMPLICIT_LINKED_HASH_COLLECTION_CLASS = + "org.apache.kafka.common.utils.ImplicitLinkedHashCollection"; static final String UNSUPPORTED_VERSION_EXCEPTION_CLASS = 
"org.apache.kafka.common.errors.UnsupportedVersionException";