diff --git a/utils/socket-utils/src/main/java17/datadog/common/socket/TunnelingJdkSocket.java b/utils/socket-utils/src/main/java17/datadog/common/socket/TunnelingJdkSocket.java index 6db94aa15ac..6f5267257b6 100644 --- a/utils/socket-utils/src/main/java17/datadog/common/socket/TunnelingJdkSocket.java +++ b/utils/socket-utils/src/main/java17/datadog/common/socket/TunnelingJdkSocket.java @@ -21,7 +21,8 @@ * Subtype UNIX socket for a higher-fidelity impersonation of TCP sockets. This is named "tunneling" * because it assumes the ultimate destination has a hostname and port. * - *
Bsed on {@link TunnelingUnixSocket}; adapted to use the built-in UDS support added in Java 16. + *
Based on {@link TunnelingUnixSocket}; adapted to use the built-in UDS support added in Java + * 16. */ final class TunnelingJdkSocket extends Socket { private final SocketAddress unixSocketAddress; @@ -34,6 +35,11 @@ final class TunnelingJdkSocket extends Socket { private boolean shutOut; private boolean closed; + protected static final int DEFAULT_BUFFER_SIZE = 8192; + // Indicate that the buffer size is not set by initializing to -1 + private int sendBufferSize = -1; + private int receiveBufferSize = -1; + TunnelingJdkSocket(final Path path) { this.unixSocketAddress = UnixDomainSocketAddress.of(path); } @@ -114,6 +120,70 @@ public SocketChannel getChannel() { return unixSocketChannel; } + @Override + public void setSendBufferSize(int size) throws SocketException { + if (isClosed()) { + throw new SocketException("Socket is closed"); + } + if (size < 0) { + throw new IllegalArgumentException("Invalid send buffer size"); + } + try { + unixSocketChannel.setOption(java.net.StandardSocketOptions.SO_SNDBUF, size); + sendBufferSize = size; + } catch (IOException e) { + throw new SocketException("Failed to set send buffer size"); + } + } + + @Override + public int getSendBufferSize() throws SocketException { + if (isClosed()) { + throw new SocketException("Socket is closed"); + } + if (sendBufferSize == -1) { + return DEFAULT_BUFFER_SIZE; + } + return sendBufferSize; + } + + @Override + public void setReceiveBufferSize(int size) throws SocketException { + if (isClosed()) { + throw new SocketException("Socket is closed"); + } + if (size < 0) { + throw new IllegalArgumentException("Invalid receive buffer size"); + } + try { + unixSocketChannel.setOption(java.net.StandardSocketOptions.SO_RCVBUF, size); + receiveBufferSize = size; + } catch (IOException e) { + throw new SocketException("Failed to set receive buffer size"); + } + } + + @Override + public int getReceiveBufferSize() throws SocketException { + if (isClosed()) { + throw new SocketException("Socket is closed"); + } + if (receiveBufferSize == -1) { + return DEFAULT_BUFFER_SIZE; + } + return receiveBufferSize; + } + + public int getStreamBufferSize() throws SocketException { + if (isClosed()) { + throw new SocketException("Socket is closed"); + } + if (sendBufferSize == -1 && receiveBufferSize == -1) { + return DEFAULT_BUFFER_SIZE; + } + return Math.max(sendBufferSize, receiveBufferSize); + } + @Override public InputStream getInputStream() throws IOException { if (isClosed()) { @@ -127,7 +197,7 @@ public InputStream getInputStream() throws IOException { } return new InputStream() { - private final ByteBuffer buffer = ByteBuffer.allocate(8192); + private final ByteBuffer buffer = ByteBuffer.allocate(getStreamBufferSize()); private final Selector selector = Selector.open(); { diff --git a/utils/socket-utils/src/test/java/datadog/common/socket/TunnelingJdkSocketTest.java b/utils/socket-utils/src/test/java/datadog/common/socket/TunnelingJdkSocketTest.java index 05cf96e94d8..74cca0d4bd1 100644 --- a/utils/socket-utils/src/test/java/datadog/common/socket/TunnelingJdkSocketTest.java +++ b/utils/socket-utils/src/test/java/datadog/common/socket/TunnelingJdkSocketTest.java @@ -1,11 +1,13 @@ package datadog.common.socket; +import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; -import static org.junit.jupiter.api.Assertions.fail; import datadog.trace.api.Config; import java.io.IOException; +import java.io.InputStream; import java.net.InetSocketAddress; +import java.net.SocketException; import java.net.StandardProtocolFamily; import java.net.UnixDomainSocketAddress; import java.nio.channels.ServerSocketChannel; @@ -28,32 +30,92 @@ public void testTimeout() throws Exception { return; } - int testTimeout = 3000; Path socketPath = getSocketPath(); UnixDomainSocketAddress socketAddress = UnixDomainSocketAddress.of(socketPath); startServer(socketAddress); TunnelingJdkSocket clientSocket = createClient(socketPath); + InputStream inputStream = clientSocket.getInputStream(); - // Test that the socket unblocks when timeout is set to >0 - clientSocket.setSoTimeout(1000); - assertTimeoutPreemptively( - Duration.ofMillis(testTimeout), () -> clientSocket.getInputStream().read()); + int testTimeout = 1000; + clientSocket.setSoTimeout(testTimeout); + assertEquals(testTimeout, clientSocket.getSoTimeout()); - // Test that the socket blocks indefinitely when timeout is set to 0, per + long startTime = System.currentTimeMillis(); + int readResult = inputStream.read(); + long endTime = System.currentTimeMillis(); + long readDuration = endTime - startTime; + int timeVariance = 100; + assertTrue(readDuration >= testTimeout && readDuration <= testTimeout + timeVariance); + assertEquals(0, readResult); + + int newTimeout = testTimeout / 2; + clientSocket.setSoTimeout(newTimeout); + assertEquals(newTimeout, clientSocket.getSoTimeout()); + assertTimeoutPreemptively(Duration.ofMillis(testTimeout), () -> inputStream.read()); + + // The socket should block indefinitely when timeout is set to 0, per // https://docs.oracle.com/en/java/javase/16/docs/api//java.base/java/net/Socket.html#setSoTimeout(int). - clientSocket.setSoTimeout(0); - boolean infiniteTimeOut = false; + int infiniteTimeout = 0; + clientSocket.setSoTimeout(infiniteTimeout); + assertEquals(infiniteTimeout, clientSocket.getSoTimeout()); try { - assertTimeoutPreemptively( - Duration.ofMillis(testTimeout), () -> clientSocket.getInputStream().read()); + assertTimeoutPreemptively(Duration.ofMillis(testTimeout), () -> inputStream.read()); + fail("Read should block indefinitely with infinite timeout"); } catch (AssertionError e) { - infiniteTimeOut = true; + // Expected } - if (!infiniteTimeOut) { - fail("Test failed: Expected infinite blocking when timeout is set to 0."); + + int invalidTimeout = -1; + assertThrows(IllegalArgumentException.class, () -> clientSocket.setSoTimeout(invalidTimeout)); + + clientSocket.close(); + assertThrows(SocketException.class, () -> clientSocket.setSoTimeout(testTimeout)); + assertThrows(SocketException.class, () -> clientSocket.getSoTimeout()); + + isServerRunning.set(false); + } + + @Test + public void testBufferSizes() throws Exception { + if (!Config.get().isJdkSocketEnabled()) { + System.out.println( + "TunnelingJdkSocket usage is disabled. Enable it by setting the property 'JDK_SOCKET_ENABLED' to 'true'."); + return; } + Path socketPath = getSocketPath(); + UnixDomainSocketAddress socketAddress = UnixDomainSocketAddress.of(socketPath); + startServer(socketAddress); + TunnelingJdkSocket clientSocket = createClient(socketPath); + + assertEquals(TunnelingJdkSocket.DEFAULT_BUFFER_SIZE, clientSocket.getSendBufferSize()); + assertEquals(TunnelingJdkSocket.DEFAULT_BUFFER_SIZE, clientSocket.getReceiveBufferSize()); + assertEquals(TunnelingJdkSocket.DEFAULT_BUFFER_SIZE, clientSocket.getStreamBufferSize()); + + int newBufferSize = TunnelingJdkSocket.DEFAULT_BUFFER_SIZE / 2; + clientSocket.setSendBufferSize(newBufferSize); + clientSocket.setReceiveBufferSize(newBufferSize / 2); + assertEquals(newBufferSize, clientSocket.getSendBufferSize()); + assertEquals(newBufferSize / 2, clientSocket.getReceiveBufferSize()); + assertEquals(newBufferSize, clientSocket.getStreamBufferSize()); + + int invalidBufferSize = -1; + assertThrows( + IllegalArgumentException.class, () -> clientSocket.setSendBufferSize(invalidBufferSize)); + assertThrows( + IllegalArgumentException.class, () -> clientSocket.setReceiveBufferSize(invalidBufferSize)); + clientSocket.close(); + assertThrows( + SocketException.class, + () -> clientSocket.setSendBufferSize(TunnelingJdkSocket.DEFAULT_BUFFER_SIZE)); + assertThrows( + SocketException.class, + () -> clientSocket.setReceiveBufferSize(TunnelingJdkSocket.DEFAULT_BUFFER_SIZE)); + assertThrows(SocketException.class, () -> clientSocket.getSendBufferSize()); + assertThrows(SocketException.class, () -> clientSocket.getReceiveBufferSize()); + assertThrows(SocketException.class, () -> clientSocket.getStreamBufferSize()); + isServerRunning.set(false); }