diff --git a/clients/src/main/java/org/apache/kafka/common/header/internals/RecordHeaders.java b/clients/src/main/java/org/apache/kafka/common/header/internals/RecordHeaders.java index 1277408270800..e6f05c1ab6887 100644 --- a/clients/src/main/java/org/apache/kafka/common/header/internals/RecordHeaders.java +++ b/clients/src/main/java/org/apache/kafka/common/header/internals/RecordHeaders.java @@ -23,7 +23,6 @@ import java.util.ArrayList; import java.util.Arrays; -import java.util.Collection; import java.util.Iterator; import java.util.List; import java.util.Objects; @@ -38,11 +37,7 @@ public RecordHeaders() { } public RecordHeaders(Header[] headers) { - if (headers == null) { - this.headers = new ArrayList<>(); - } else { - this.headers = new ArrayList<>(Arrays.asList(headers)); - } + this(headers == null ? null : Arrays.asList(headers)); } public RecordHeaders(Iterable
headers) { @@ -51,12 +46,12 @@ public RecordHeaders(Iterable
headers) { this.headers = new ArrayList<>(); } else if (headers instanceof RecordHeaders) { this.headers = new ArrayList<>(((RecordHeaders) headers).headers); - } else if (headers instanceof Collection) { - this.headers = new ArrayList<>((Collection
) headers); } else { this.headers = new ArrayList<>(); - for (Header header : headers) + for (Header header : headers) { + Objects.requireNonNull(header, "Header cannot be null."); this.headers.add(header); + } } } diff --git a/clients/src/test/java/org/apache/kafka/common/header/internals/RecordHeadersTest.java b/clients/src/test/java/org/apache/kafka/common/header/internals/RecordHeadersTest.java index 5b9f95ea91f18..8a6379992051f 100644 --- a/clients/src/test/java/org/apache/kafka/common/header/internals/RecordHeadersTest.java +++ b/clients/src/test/java/org/apache/kafka/common/header/internals/RecordHeadersTest.java @@ -27,6 +27,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -206,9 +207,15 @@ public void testNew() throws IOException { assertEquals(2, getCount(newHeaders)); } - @Test(expected = NullPointerException.class) + @Test public void shouldThrowNpeWhenAddingNullHeader() { - new RecordHeaders().add(null); + final RecordHeaders recordHeaders = new RecordHeaders(); + assertThrows(NullPointerException.class, () -> recordHeaders.add(null)); + } + + @Test + public void shouldThrowNpeWhenAddingCollectionWithNullHeader() { + assertThrows(NullPointerException.class, () -> new RecordHeaders(new Header[1])); } private int getCount(Headers headers) { diff --git a/connect/api/src/main/java/org/apache/kafka/connect/header/ConnectHeaders.java b/connect/api/src/main/java/org/apache/kafka/connect/header/ConnectHeaders.java index e9c8bce2734af..5c37ddc5e58b4 100644 --- a/connect/api/src/main/java/org/apache/kafka/connect/header/ConnectHeaders.java +++ b/connect/api/src/main/java/org/apache/kafka/connect/header/ConnectHeaders.java @@ -63,6 +63,7 @@ public ConnectHeaders(Iterable
original) { } else { headers = new LinkedList<>(); for (Header header : original) { + Objects.requireNonNull(header, "Unable to add a null header."); headers.add(header); } } @@ -75,7 +76,7 @@ public int size() { @Override public boolean isEmpty() { - return headers == null ? true : headers.isEmpty(); + return headers == null || headers.isEmpty(); } @Override diff --git a/connect/api/src/test/java/org/apache/kafka/connect/header/ConnectHeadersTest.java b/connect/api/src/test/java/org/apache/kafka/connect/header/ConnectHeadersTest.java index 72418ba47ffa7..f4f11d004d6e8 100644 --- a/connect/api/src/test/java/org/apache/kafka/connect/header/ConnectHeadersTest.java +++ b/connect/api/src/test/java/org/apache/kafka/connect/header/ConnectHeadersTest.java @@ -34,6 +34,7 @@ import java.math.BigDecimal; import java.math.RoundingMode; import java.util.ArrayList; +import java.util.Arrays; import java.util.Calendar; import java.util.Collections; import java.util.GregorianCalendar; @@ -47,6 +48,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -498,6 +500,18 @@ public void shouldDuplicateAndAlwaysReturnEquivalentButDifferentObject() { assertNotSame(headers, headers.duplicate()); } + @Test + public void shouldNotAllowToAddNullHeader() { + final ConnectHeaders headers = new ConnectHeaders(); + assertThrows(NullPointerException.class, () -> headers.add(null)); + } + + @Test + public void shouldThrowNpeWhenAddingCollectionWithNullHeader() { + final Iterable
header = Arrays.asList(new ConnectHeader[1]); + assertThrows(NullPointerException.class, () -> new ConnectHeaders(header)); + } + protected void assertSchemaMatches(Schema schema, Object value) { headers.checkSchemaMatches(new SchemaAndValue(schema.schema(), value)); }