diff --git a/src/main/java/org/apache/commons/io/channels/ByteArraySeekableByteChannel.java b/src/main/java/org/apache/commons/io/channels/ByteArraySeekableByteChannel.java index 4135a815e83..2cd8009203f 100644 --- a/src/main/java/org/apache/commons/io/channels/ByteArraySeekableByteChannel.java +++ b/src/main/java/org/apache/commons/io/channels/ByteArraySeekableByteChannel.java @@ -68,7 +68,7 @@ public static ByteArraySeekableByteChannel wrap(final byte[] bytes) { private byte[] data; private volatile boolean closed; - private int position; + private long position; private int size; private final ReentrantLock lock = new ReentrantLock(); @@ -126,11 +126,10 @@ private void checkOpen() throws ClosedChannelException { } } - private int checkRange(final long newSize, final String method) { - if (newSize < 0L || newSize > IOUtils.SOFT_MAX_ARRAY_LENGTH) { - throw new IllegalArgumentException(String.format("%s must be in range [0..%,d]: %,d", method, IOUtils.SOFT_MAX_ARRAY_LENGTH, newSize)); + private void checkRange(final long newSize, final String method) { + if (newSize < 0L) { + throw new IllegalArgumentException(String.format("%s must be positive: %,d", method, newSize)); } - return (int) newSize; } @Override @@ -166,10 +165,10 @@ public long position() throws ClosedChannelException { @Override public SeekableByteChannel position(final long newPosition) throws IOException { checkOpen(); - final int intPos = checkRange(newPosition, "position()"); + checkRange(newPosition, "position()"); lock.lock(); try { - position = intPos; + position = newPosition; } finally { lock.unlock(); } @@ -181,15 +180,18 @@ public int read(final ByteBuffer buf) throws IOException { checkOpen(); lock.lock(); try { + if (position > Integer.MAX_VALUE) { + return IOUtils.EOF; + } int wanted = buf.remaining(); - final int possible = size - position; + final int possible = size - (int) position; if (possible <= 0) { return IOUtils.EOF; } if (wanted > possible) { wanted = possible; } - buf.put(data, position, wanted); + buf.put(data, (int) position, wanted); position += wanted; return wanted; } finally { @@ -238,14 +240,14 @@ public byte[] toByteArray() { @Override public SeekableByteChannel truncate(final long newSize) throws ClosedChannelException { checkOpen(); - final int intSize = checkRange(newSize, "truncate()"); + checkRange(newSize, "truncate()"); lock.lock(); try { - if (size > intSize) { - size = intSize; + if (size > newSize) { + size = (int) newSize; } - if (position > intSize) { - position = intSize; + if (position > newSize) { + position = newSize; } } finally { lock.unlock(); @@ -256,21 +258,28 @@ public SeekableByteChannel truncate(final long newSize) throws ClosedChannelExce @Override public int write(final ByteBuffer b) throws IOException { checkOpen(); + if (position > Integer.MAX_VALUE) { + throw new IOException("position > Integer.MAX_VALUE"); + } lock.lock(); try { final int wanted = b.remaining(); - final int possibleWithoutResize = Math.max(0, size - position); - if (wanted > possibleWithoutResize) { - final int newSize = position + wanted; - if (newSize < 0 || newSize > IOUtils.SOFT_MAX_ARRAY_LENGTH) { // overflow - throw new OutOfMemoryError("required array size " + Integer.toUnsignedString(newSize) + " too large"); - } - resize(newSize); + // intPos <= Integer.MAX_VALUE + final int intPos = (int) position; + final long newPosition = position + wanted; + if (newPosition > IOUtils.SOFT_MAX_ARRAY_LENGTH) { + throw new IOException(String.format("Requested array size %,d is too large.", newPosition)); } - b.get(data, position, wanted); - position += wanted; - if (size < position) { - size = position; + if (newPosition > size) { + final int newPositionInt = (int) newPosition; + // Ensure that newPositionInt ≤ data.length + resize(newPositionInt); + size = newPositionInt; + } + b.get(data, intPos, wanted); + position = newPosition; + if (size < intPos) { + size = intPos; } return wanted; } finally { diff --git a/src/test/java/org/apache/commons/io/channels/AbstractSeekableByteChannelTest.java b/src/test/java/org/apache/commons/io/channels/AbstractSeekableByteChannelTest.java index abf96d8f47d..afb2245a49f 100644 --- a/src/test/java/org/apache/commons/io/channels/AbstractSeekableByteChannelTest.java +++ b/src/test/java/org/apache/commons/io/channels/AbstractSeekableByteChannelTest.java @@ -47,7 +47,7 @@ */ abstract class AbstractSeekableByteChannelTest { - private SeekableByteChannel channel; + protected SeekableByteChannel channel; @TempDir protected Path tempDir; @@ -87,6 +87,7 @@ void testCloseMultipleTimes() throws IOException { assertFalse(channel.isOpen()); } + @Test void testConcurrentPositionAndSizeQueries() throws IOException { final byte[] data = "test data".getBytes(); @@ -136,6 +137,20 @@ void testPositionBeyondSize() throws IOException { assertEquals(4, channel.size()); // Size should not change } + @Test + void testPositionBeyondSizeRead() throws IOException { + final ByteBuffer buffer = ByteBuffer.allocate(1); + channel.position(channel.size() + 1); + assertEquals(channel.size() + 1, channel.position()); + assertEquals(-1, channel.read(buffer)); + channel.position(Integer.MAX_VALUE + 1L); + assertEquals(Integer.MAX_VALUE + 1L, channel.position()); + assertEquals(-1, channel.read(buffer)); + assertThrows(IllegalArgumentException.class, () -> channel.position(-1)); + assertThrows(IllegalArgumentException.class, () -> channel.position(Integer.MIN_VALUE)); + assertThrows(IllegalArgumentException.class, () -> channel.position(Long.MIN_VALUE)); + } + @ParameterizedTest @CsvSource({ "0, 0", "5, 5", "10, 10", "100, 100" }) void testPositionInBounds(final long newPosition, final long expectedPosition) throws IOException { @@ -149,6 +164,7 @@ void testPositionInBounds(final long newPosition, final long expectedPosition) t assertEquals(expectedPosition, channel.position()); } + @Test void testPositionNegative() { assertThrows(IllegalArgumentException.class, () -> channel.position(-1)); @@ -292,6 +308,18 @@ void testSizeSameOnOverwrite() throws IOException { assertEquals(11, channel.size()); // Size should not change } + @Test + void testTrucateBeyondSizeReadWrite() throws IOException { + final ByteBuffer buffer = ByteBuffer.allocate(1); + channel.truncate(channel.size() + 1); + assertEquals(-1, channel.read(buffer)); + channel.truncate(Integer.MAX_VALUE + 1L); + assertEquals(-1, channel.read(buffer)); + assertThrows(IllegalArgumentException.class, () -> channel.truncate(-1)); + assertThrows(IllegalArgumentException.class, () -> channel.truncate(Integer.MIN_VALUE)); + assertThrows(IllegalArgumentException.class, () -> channel.truncate(Long.MIN_VALUE)); + } + @Test void testTruncateNegative() { assertThrows(IllegalArgumentException.class, () -> channel.truncate(-1)); diff --git a/src/test/java/org/apache/commons/io/channels/ByteArraySeekableByteChannelCompressTest.java b/src/test/java/org/apache/commons/io/channels/ByteArraySeekableByteChannelCompressTest.java index 66c8e60b91a..60c294f0d65 100644 --- a/src/test/java/org/apache/commons/io/channels/ByteArraySeekableByteChannelCompressTest.java +++ b/src/test/java/org/apache/commons/io/channels/ByteArraySeekableByteChannelCompressTest.java @@ -162,16 +162,39 @@ void testShouldThrowExceptionOnWritingToClosedChannel() { } @Test - void testShouldThrowExceptionWhenSettingIncorrectPosition() { + void testThrowWhenSettingIncorrectPosition() throws IOException { try (ByteArraySeekableByteChannel c = new ByteArraySeekableByteChannel()) { - assertThrows(IllegalArgumentException.class, () -> c.position(Integer.MAX_VALUE + 1L)); + final ByteBuffer buffer = ByteBuffer.allocate(1); + // write + c.write(buffer); + assertEquals(1, c.position()); + // bad pos A + c.position(c.size() + 1); + assertEquals(c.size() + 1, c.position()); + assertEquals(-1, c.read(buffer)); + // bad pos B + c.position(Integer.MAX_VALUE + 1L); + assertEquals(Integer.MAX_VALUE + 1L, c.position()); + assertEquals(-1, c.read(buffer)); + assertThrows(IOException.class, () -> c.write(buffer)); + // negative input is the only illegal input + assertThrows(IllegalArgumentException.class, () -> c.position(-1)); + assertThrows(IllegalArgumentException.class, () -> c.position(Integer.MIN_VALUE)); + assertThrows(IllegalArgumentException.class, () -> c.position(Long.MIN_VALUE)); } } @Test - void testShouldThrowExceptionWhenTruncatingToIncorrectSize() { + void testThrowWhenTruncatingToIncorrectSize() throws IOException { try (ByteArraySeekableByteChannel c = new ByteArraySeekableByteChannel()) { - assertThrows(IllegalArgumentException.class, () -> c.truncate(Integer.MAX_VALUE + 1L)); + final ByteBuffer buffer = ByteBuffer.allocate(1); + c.truncate(c.size() + 1); + assertEquals(-1, c.read(buffer)); + c.truncate(Integer.MAX_VALUE + 1L); + assertEquals(-1, c.read(buffer)); + assertThrows(IllegalArgumentException.class, () -> c.truncate(-1)); + assertThrows(IllegalArgumentException.class, () -> c.truncate(Integer.MIN_VALUE)); + assertThrows(IllegalArgumentException.class, () -> c.truncate(Long.MIN_VALUE)); } } diff --git a/src/test/java/org/apache/commons/io/channels/ByteArraySeekableByteChannelTest.java b/src/test/java/org/apache/commons/io/channels/ByteArraySeekableByteChannelTest.java index 5b29b3f1be0..79f971e4a0c 100644 --- a/src/test/java/org/apache/commons/io/channels/ByteArraySeekableByteChannelTest.java +++ b/src/test/java/org/apache/commons/io/channels/ByteArraySeekableByteChannelTest.java @@ -84,6 +84,22 @@ void testConstructorInvalid() { assertThrows(NullPointerException.class, () -> ByteArraySeekableByteChannel.wrap(null)); } + @Test + void testPositionBeyondSizeReadWrite() throws IOException { + final ByteBuffer buffer = ByteBuffer.allocate(1); + channel.position(channel.size() + 1); + assertEquals(channel.size() + 1, channel.position()); + assertEquals(-1, channel.read(buffer)); + channel.position(Integer.MAX_VALUE + 1L); + assertEquals(Integer.MAX_VALUE + 1L, channel.position()); + assertEquals(-1, channel.read(buffer)); + // ByteArraySeekableByteChannel has a hard boundary at Integer.MAX_VALUE, files don't. + assertThrows(IOException.class, () -> channel.write(buffer)); + assertThrows(IllegalArgumentException.class, () -> channel.position(-1)); + assertThrows(IllegalArgumentException.class, () -> channel.position(Integer.MIN_VALUE)); + assertThrows(IllegalArgumentException.class, () -> channel.position(Long.MIN_VALUE)); + } + @ParameterizedTest @MethodSource void testShouldResizeWhenWritingMoreDataThanCapacity(final byte[] data, final int wanted) throws IOException {